From 41f6daf36fd8ae0615c018bfd495492e62c5c5d4 Mon Sep 17 00:00:00 2001 From: loneylee Date: Mon, 1 Jul 2024 11:29:25 +0800 Subject: [PATCH 1/6] add bnlj --- .../backendsapi/clickhouse/CHMetricsApi.scala | 33 +- .../clickhouse/CHSparkPlanExecApi.scala | 22 +- ...oadcastNestedLoopJoinExecTransformer.scala | 143 +++++ ...roadcastNestedLoopJoinMetricsUpdater.scala | 123 ++++ .../GlutenClickHouseTPCDSAbstractSuite.scala | 7 +- cpp-ch/local-engine/Common/CHUtil.cpp | 21 + cpp-ch/local-engine/Common/CHUtil.h | 6 + .../Join/BroadCastJoinBuilder.cpp | 57 +- .../Join/StorageJoinFromReadBuffer.cpp | 85 +-- .../Join/StorageJoinFromReadBuffer.h | 6 +- cpp-ch/local-engine/Parser/CrossRelParser.cpp | 573 ++++++++++++++++++ cpp-ch/local-engine/Parser/CrossRelParser.h | 67 ++ cpp-ch/local-engine/Parser/RelParser.cpp | 2 + .../Parser/SerializedPlanParser.cpp | 1 + .../Parser/SerializedPlanParser.h | 1 + .../substrait/proto/substrait/algebra.proto | 1 + .../backendsapi/BackendSettingsApi.scala | 2 +- ...oadcastNestedLoopJoinExecTransformer.scala | 23 +- .../apache/gluten/execution/JoinUtils.scala | 12 +- .../columnar/MiscColumnarRules.scala | 22 + .../columnar/heuristic/HeuristicApplier.scala | 3 +- 21 files changed, 1112 insertions(+), 98 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala create mode 100644 cpp-ch/local-engine/Parser/CrossRelParser.cpp create mode 100644 cpp-ch/local-engine/Parser/CrossRelParser.h diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index a5fb4a1853e8..59c38b02194a 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ 
b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -348,16 +348,33 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { metrics: Map[String, SQLMetric]): MetricsUpdater = new HashJoinMetricsUpdater(metrics) override def genNestedLoopJoinTransformerMetrics( - sparkContext: SparkContext): Map[String, SQLMetric] = { - throw new UnsupportedOperationException( - s"NestedLoopJoinTransformer metrics update is not supported in CH backend") - } + sparkContext: SparkContext): Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of output vectors"), + "outputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of output bytes"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input bytes"), + "extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"), + "inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"), + "outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"), + "streamPreProjectionTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of stream side preProjection"), + "buildPreProjectionTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of build side preProjection"), + "postProjectTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"), + "probeTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of probe"), + "totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "time"), + "fillingRightJoinSideTime" -> SQLMetrics.createTimingMetric( + sparkContext, + "filling right join side time"), + "conditionTime" -> SQLMetrics.createTimingMetric(sparkContext, "join condition time") + ) override def 
genNestedLoopJoinTransformerMetricsUpdater( - metrics: Map[String, SQLMetric]): MetricsUpdater = { - throw new UnsupportedOperationException( - s"NestedLoopJoinTransformer metrics update is not supported in CH backend") - } + metrics: Map[String, SQLMetric]): MetricsUpdater = new BroadcastNestedLoopJoinMetricsUpdater( + metrics) override def genSampleTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = { throw new UnsupportedOperationException( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index f5feade886b9..fc2ebda397e6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -373,8 +373,13 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer = - throw new GlutenNotSupportException( - "BroadcastNestedLoopJoinExecTransformer is not supported in ch backend.") + CHBroadcastNestedLoopJoinExecTransformer( + left, + right, + buildSide, + joinType, + condition + ) override def genSampleExecTransformer( lowerBound: Double, @@ -460,16 +465,23 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { child: SparkPlan, numOutputRows: SQLMetric, dataSize: SQLMetric): BuildSideRelation = { - val hashedRelationBroadcastMode = mode.asInstanceOf[HashedRelationBroadcastMode] + + val buildKeys: Seq[Expression] = mode match { + case mode1: HashedRelationBroadcastMode => + mode1.key + case _ => + // IdentityBroadcastMode + Seq.empty + } + val (newChild, newOutput, newBuildKeys) = if ( - hashedRelationBroadcastMode.key + buildKeys .forall(k => k.isInstanceOf[AttributeReference] || k.isInstanceOf[BoundReference]) ) { 
(child, child.output, Seq.empty[Expression]) } else { // pre projection in case of expression join keys - val buildKeys = hashedRelationBroadcastMode.key val appendedProjections = new ArrayBuffer[NamedExpression]() val preProjectionBuildKeys = buildKeys.zipWithIndex.map { case (e, idx) => diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala new file mode 100644 index 000000000000..55f86bb32cce --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.execution + +import org.apache.gluten.backendsapi.BackendsApiManager + +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.GlutenDriverEndpoint +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashJoin} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +import com.google.protobuf.{Any, StringValue} + +case class CHBroadcastNestedLoopJoinExecTransformer( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) + extends BroadcastNestedLoopJoinExecTransformer( + left, + right, + buildSide, + joinType, + condition + ) { + // Unique ID for builded table + lazy val buildBroadcastTableId: String = "BuiltBroadcastTable-" + buildPlan.id + + lazy val (buildKeyExprs, streamedKeyExprs) = { + require( + leftKeys.length == rightKeys.length && + leftKeys + .map(_.dataType) + .zip(rightKeys.map(_.dataType)) + .forall(types => sameType(types._1, types._2)), + "Join keys from two sides should have same length and types" + ) + // Spark has an improvement which would patch integer joins keys to a Long value. + // But this improvement would cause add extra project before hash join in velox, + // disabling this improvement as below would help reduce the project. 
+ val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) { + (HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys)) + } else { + (leftKeys, rightKeys) + } + if (needSwitchChildren) { + (lkeys, rkeys) + } else { + (rkeys, lkeys) + } + } + + override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { + val streamedRDD = getColumnarInputRDDs(streamedPlan) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionId != null) { + GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId) + } else { + logWarning( + s"Can't not trace broadcast table data $buildBroadcastTableId" + + s" because execution id is null." + + s" Will clean up until expire time.") + } + val broadcast = buildPlan.executeBroadcast[BuildSideRelation]() + val context = + BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId) + val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context) + // FIXME: Do we have to make build side a RDD? 
+ streamedRDD :+ broadcastRDD + } + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, + newRight: SparkPlan): CHBroadcastNestedLoopJoinExecTransformer = + copy(left = newLeft, right = newRight) + + def isMixedCondition(cond: Option[Expression]): Boolean = { + val res = if (cond.isDefined) { + val leftOutputSet = left.outputSet + val rightOutputSet = right.outputSet + val allReferences = cond.get.references + !(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet)) + } else { + false + } + res + } + + def sameType(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, _), ArrayType(toElement, _)) => + sameType(fromElement, toElement) + + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + sameType(fromKey, toKey) && + sameType(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (l, r) => + l.name.equalsIgnoreCase(r.name) && + sameType(l.dataType, r.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } + + override def genJoinParameters(): Any = { + val joinParametersStr = new StringBuffer("JoinParameters:") + joinParametersStr + .append("buildHashTableId=") + .append(buildBroadcastTableId) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } + +} diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala new file mode 100644 index 000000000000..2f8875cf9b74 --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.metrics + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.metric.SQLMetric + +class BroadcastNestedLoopJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) + extends MetricsUpdater + with Logging { + + override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = { + try { + if (opMetrics != null) { + val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics] + if (!operatorMetrics.metricsList.isEmpty && operatorMetrics.joinParams != null) { + val joinParams = operatorMetrics.joinParams + var currentIdx = operatorMetrics.metricsList.size() - 1 + var totalTime = 0L + + // build side pre projection + if (joinParams.buildPreProjectionNeeded) { + metrics("buildPreProjectionTime") += + (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong + metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors + totalTime += operatorMetrics.metricsList.get(currentIdx).time + currentIdx -= 1 + } + + // stream side pre projection + if (joinParams.streamPreProjectionNeeded) { + metrics("streamPreProjectionTime") += + (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong + 
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors + totalTime += operatorMetrics.metricsList.get(currentIdx).time + currentIdx -= 1 + } + + // update fillingRightJoinSideTime + MetricsUtil + .getAllProcessorList(operatorMetrics.metricsList.get(currentIdx)) + .foreach( + processor => { + if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) { + metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong + } + }) + + // joining + val joinMetricsData = operatorMetrics.metricsList.get(currentIdx) + metrics("outputVectors") += joinMetricsData.outputVectors + metrics("inputWaitTime") += (joinMetricsData.inputWaitTime / 1000L).toLong + metrics("outputWaitTime") += (joinMetricsData.outputWaitTime / 1000L).toLong + totalTime += joinMetricsData.time + + MetricsUtil + .getAllProcessorList(joinMetricsData) + .foreach( + processor => { + if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) { + metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong + } + if (processor.name.equalsIgnoreCase("FilterTransform")) { + metrics("conditionTime") += (processor.time / 1000L).toLong + } + if (processor.name.equalsIgnoreCase("JoiningTransform")) { + metrics("probeTime") += (processor.time / 1000L).toLong + } + if ( + !BroadcastNestedLoopJoinMetricsUpdater.INCLUDING_PROCESSORS.contains( + processor.name) + ) { + metrics("extraTime") += (processor.time / 1000L).toLong + } + if ( + BroadcastNestedLoopJoinMetricsUpdater.CH_PLAN_NODE_NAME.contains(processor.name) + ) { + metrics("numOutputRows") += processor.outputRows + metrics("outputBytes") += processor.outputBytes + metrics("numInputRows") += processor.inputRows + metrics("inputBytes") += processor.inputBytes + } + }) + + currentIdx -= 1 + + // post projection + if (joinParams.postProjectionNeeded) { + metrics("postProjectTime") += + (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong + metrics("outputVectors") += 
operatorMetrics.metricsList.get(currentIdx).outputVectors + totalTime += operatorMetrics.metricsList.get(currentIdx).time + currentIdx -= 1 + } + metrics("totalTime") += (totalTime / 1000L).toLong + } + } + } catch { + case e: Exception => + logError(s"Updating native metrics failed due to ${e.getCause}.") + throw e + } + } +} + +object BroadcastNestedLoopJoinMetricsUpdater { + val INCLUDING_PROCESSORS = Array("JoiningTransform", "FillingRightJoinSide", "FilterTransform") + val CH_PLAN_NODE_NAME = Array("JoiningTransform") +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala index bcdc1f5ef514..ed2e5c10e6c6 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -57,16 +57,13 @@ abstract class GlutenClickHouseTPCDSAbstractSuite } val noFallBack = queryNum match { case i - if i == 10 || i == 16 || i == 28 || i == 35 || i == 45 || i == 77 || - i == 88 || i == 90 || i == 94 => + if i == 10 || i == 16 || i == 35 || i == 45 || i == 77 || + i == 94 => // Q10 BroadcastHashJoin, ExistenceJoin // Q16 ShuffledHashJoin, NOT condition - // Q28 BroadcastNestedLoopJoin // Q35 BroadcastHashJoin, ExistenceJoin // Q45 BroadcastHashJoin, ExistenceJoin // Q77 CartesianProduct - // Q88 BroadcastNestedLoopJoin - // Q90 BroadcastNestedLoopJoin // Q94 BroadcastHashJoin, LeftSemi, NOT condition (false, false) case j if j == 38 || j == 87 => diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 800039f1d262..1889ed519fa0 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -128,6 +128,27 @@ DB::Block BlockUtil::buildHeader(const DB::NamesAndTypesList & 
names_types_list) return DB::Block(cols); } +/// The column names may be different in two blocks. +/// and the nullability also could be different, with TPCDS-Q1 as an example. +DB::ColumnWithTypeAndName +BlockUtil::convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column) +{ + if (sample_column.type->equals(*column.type)) + return {column.column, column.type, sample_column.name}; + else if (sample_column.type->isNullable() && !column.type->isNullable() && DB::removeNullable(sample_column.type)->equals(*column.type)) + { + auto nullable_column = column; + DB::JoinCommon::convertColumnToNullable(nullable_column); + return {nullable_column.column, sample_column.type, sample_column.name}; + } + else + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "Columns have different types. original:{} expected:{}", + column.dumpStructure(), + sample_column.dumpStructure()); +} + /** * There is a special case with which we need be careful. In spark, struct/map/list are always * wrapped in Nullable, but this should not happen in clickhouse. diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 3ac0f63ce10b..ff25415c0cbb 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -47,6 +47,8 @@ static const std::unordered_set LONG_VALUE_SETTINGS{ class BlockUtil { public: + static constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; + // Build a header block with a virtual column which will be // use to indicate the number of rows in a block. // Commonly seen in the following quries: @@ -72,6 +74,10 @@ class BlockUtil const std::unordered_set & columns_to_skip_flatten = {}); static DB::Block concatenateBlocksMemoryEfficiently(std::vector && blocks); + + /// The column names may be different in two blocks. + /// and the nullability also could be different, with TPCDS-Q1 as an example. 
+ static DB::ColumnWithTypeAndName convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column); }; class PODArrayUtil diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 1c79a00a7c4c..3f3c7e6c32aa 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -15,16 +15,18 @@ * limitations under the License. */ #include "BroadCastJoinBuilder.h" + #include #include #include #include #include +#include #include #include #include #include -#include +#include #include #include @@ -52,6 +54,20 @@ jlong callJavaGet(const std::string & id) return result; } +DB::Block resetBuildTableBlockName(Block & block, bool only_one = false) +{ + DB::ColumnsWithTypeAndName new_cols; + for (const auto & col : block) + { + // Add a prefix to avoid column name conflicts with left table. + new_cols.emplace_back(col.column, col.type, BlockUtil::RIHGT_COLUMN_PREFIX + col.name); + + if (only_one) + break; + } + return DB::Block(new_cols); +} + void cleanBuildHashTable(const std::string & hash_table_id, jlong instance) { /// Thread status holds raw pointer on query context, thus it always must be destroyed @@ -88,7 +104,8 @@ std::shared_ptr buildJoin( auto join_key_list = Poco::StringTokenizer(join_keys, ","); Names key_names; for (const auto & key_name : join_key_list) - key_names.emplace_back(key_name); + key_names.emplace_back(BlockUtil::RIHGT_COLUMN_PREFIX + key_name); + DB::JoinKind kind; DB::JoinStrictness strictness; @@ -97,10 +114,44 @@ std::shared_ptr buildJoin( substrait::NamedStruct substrait_struct; substrait_struct.ParseFromString(named_struct); Block header = TypeParser::buildBlockFromNamedStruct(substrait_struct); + header = resetBuildTableBlockName(header); + + Blocks data; + { + bool header_empty = header.getNamesAndTypesList().empty(); + bool only_one_column = header_empty; + NativeReader 
block_stream(input); + ProfileInfo info; + while (Block block = block_stream.read()) + { + if (header_empty) + { + // In bnlj, buidside output maybe empty, + // we use buildside header only for loop + // Like: select count(*) from t1 left join t2 + header = resetBuildTableBlockName(block, true); + header_empty = false; + } + + DB::ColumnsWithTypeAndName columns; + for (size_t i = 0; i < block.columns(); ++i) + { + const auto & column = block.getByPosition(i); + columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); + if (only_one_column) + break; + } + + DB::Block final_block(columns); + info.update(final_block); + data.emplace_back(std::move(final_block)); + } + } + ColumnsDescription columns_description(header.getNamesAndTypesList()); return make_shared( - input, + data, row_count, key_names, true, diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp index 326e11a84f81..2f5afd434b41 100644 --- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp +++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp @@ -16,12 +16,10 @@ */ #include "StorageJoinFromReadBuffer.h" -#include #include #include #include -#include -#include +#include #include #include @@ -43,8 +41,6 @@ extern const int DEADLOCK_AVOIDED; using namespace DB; -constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; - DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata & storage_metadata_, JoinKind kind) { DB::ColumnsWithTypeAndName new_cols; @@ -52,7 +48,7 @@ DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata & stora for (const auto & col : block) { // Add a prefix to avoid column name conflicts with left table. 
- new_cols.emplace_back(col.column, col.type, RIHGT_COLUMN_PREFIX + col.name); + new_cols.emplace_back(col.column, col.type, col.name); if (use_nulls && isLeftOrFull(kind)) { auto & new_col = new_cols.back(); @@ -66,7 +62,7 @@ namespace local_engine { StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( - DB::ReadBuffer & in, + Blocks & data, size_t row_count_, const Names & key_names_, bool use_nulls_, @@ -77,7 +73,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( const ConstraintsDescription & constraints, const String & comment, const bool overwrite_) - : key_names({}), use_nulls(use_nulls_), row_count(row_count_), overwrite(overwrite_) + : key_names(key_names_), use_nulls(use_nulls_), row_count(row_count_), overwrite(overwrite_) { storage_metadata.setColumns(columns); storage_metadata.setConstraints(constraints); @@ -86,74 +82,33 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( for (const auto & key : key_names_) if (!storage_metadata.getColumns().hasPhysical(key)) throw Exception(ErrorCodes::NO_SUCH_COLUMN_IN_TABLE, "Key column ({}) does not exist in table declaration.", key); - for (const auto & name : key_names_) - key_names.push_back(RIHGT_COLUMN_PREFIX + name); auto table_join = std::make_shared(SizeLimits(), true, kind, strictness, key_names); - right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind()); - /// If there is mixed join conditions, need to build the hash join lazily, which rely on the real table join. - if (!has_mixed_join_condition) - buildJoin(in, right_sample_block, table_join); - else - collectAllInputs(in, right_sample_block); -} -/// The column names may be different in two blocks. -/// and the nullability also could be different, with TPCDS-Q1 as an example. 
-static DB::ColumnWithTypeAndName convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column) -{ - if (sample_column.type->equals(*column.type)) - return {column.column, column.type, sample_column.name}; - else if ( - sample_column.type->isNullable() && !column.type->isNullable() - && DB::removeNullable(sample_column.type)->equals(*column.type)) + if (key_names.empty()) { - auto nullable_column = column; - DB::JoinCommon::convertColumnToNullable(nullable_column); - return {nullable_column.column, sample_column.type, sample_column.name}; + // For bnlj cross join, keys clauses should be empty. + table_join->resetKeys(); } + + right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind()); + /// If there is mixed join conditions, need to build the hash join lazily, which rely on the real table join. + if (!has_mixed_join_condition) + buildJoin(data, right_sample_block, table_join); else - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Columns have different types. 
original:{} expected:{}", - column.dumpStructure(), - sample_column.dumpStructure()); + collectAllInputs(data, right_sample_block); } -void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block header, std::shared_ptr analyzed_join) +void StorageJoinFromReadBuffer::buildJoin(Blocks & data, const Block header, std::shared_ptr analyzed_join) { - local_engine::NativeReader block_stream(in); - ProfileInfo info; join = std::make_shared(analyzed_join, header, overwrite, row_count); - while (Block block = block_stream.read()) - { - DB::ColumnsWithTypeAndName columns; - for (size_t i = 0; i < block.columns(); ++i) - { - const auto & column = block.getByPosition(i); - columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); - } - DB::Block final_block(columns); - info.update(final_block); - join->addBlockToJoin(final_block, true); - } + for (Block block : data) + join->addBlockToJoin(std::move(block), true); } -void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const DB::Block header) +void StorageJoinFromReadBuffer::collectAllInputs(Blocks & data, const DB::Block) { - local_engine::NativeReader block_stream(in); - ProfileInfo info; - while (Block block = block_stream.read()) - { - DB::ColumnsWithTypeAndName columns; - for (size_t i = 0; i < block.columns(); ++i) - { - const auto & column = block.getByPosition(i); - columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); - } - DB::Block final_block(columns); - info.update(final_block); - input_blocks.emplace_back(std::move(final_block)); - } + for (Block block : data) + input_blocks.emplace_back(std::move(block)); } void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_ptr analyzed_join) @@ -174,7 +129,7 @@ void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_pt for (size_t i = 0; i < block.columns(); ++i) { const auto & column = block.getByPosition(i); - 
columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); + columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); } DB::Block final_block(columns); join->addBlockToJoin(final_block, true); diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h index ddefda69c30f..600210e66869 100644 --- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h +++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h @@ -35,7 +35,7 @@ class StorageJoinFromReadBuffer { public: StorageJoinFromReadBuffer( - DB::ReadBuffer & in_, + DB::Blocks & data, size_t row_count, const DB::Names & key_names_, bool use_nulls_, @@ -65,8 +65,8 @@ class StorageJoinFromReadBuffer std::shared_ptr join = nullptr; void readAllBlocksFromInput(DB::ReadBuffer & in); - void buildJoin(DB::ReadBuffer & in, const DB::Block header, std::shared_ptr analyzed_join); - void collectAllInputs(DB::ReadBuffer & in, const DB::Block header); + void buildJoin(DB::Blocks & data, const DB::Block header, std::shared_ptr analyzed_join); + void collectAllInputs(DB::Blocks & data, const DB::Block header); void buildJoinLazily(DB::Block header, std::shared_ptr analyzed_join); }; } diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp new file mode 100644 index 000000000000..071588a9f26e --- /dev/null +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "CrossRelParser.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_TYPE; + extern const int BAD_ARGUMENTS; +} +} + +struct JoinOptimizationInfo +{ + bool is_broadcast = false; + bool is_smj = false; + bool is_null_aware_anti_join = false; + bool is_existence_join = false; + std::string storage_join_key; +}; + +using namespace DB; + + +String parseJoinOptimizationInfos(const substrait::CrossRel & join) +{ + google::protobuf::StringValue optimization; + optimization.ParseFromString(join.advanced_extension().optimization().value()); + JoinOptimizationInfo info; + auto a = optimization.value(); + String storage_join_key; + ReadBufferFromString in(optimization.value()); + assertString("JoinParameters:", in); + assertString("buildHashTableId=", in); + readString(storage_join_key, in); + return storage_join_key; +} + +void reorderJoinOutput2(DB::QueryPlan & plan, DB::Names cols) +{ + ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); + NamesWithAliases project_cols; + for (const auto & col : cols) + { + project_cols.emplace_back(NameWithAlias(col, col)); + } + project->project(project_cols); + QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); + project_step->setStepDescription("Reorder Join Output"); + 
plan.addStep(std::move(project_step)); +} + +namespace local_engine +{ + +std::pair getJoinKindAndStrictness2(substrait::CrossRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: + case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Cross, DB::JoinStrictness::All}; + // case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + // return {DB::JoinKind::Left, DB::JoinStrictness::All}; + // + // case substrait::CrossRel_JoinType_JOIN_TYPE_RIGHT: + // return {DB::JoinKind::Right, DB::JoinStrictness::All}; + // case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: + // return {DB::JoinKind::Full, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} +std::shared_ptr createDefaultTableJoin2(substrait::CrossRel_JoinType join_type) +{ + auto & global_context = SerializedPlanParser::global_context; + auto table_join = std::make_shared( + global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); + + std::pair kind_and_strictness = getJoinKindAndStrictness2(join_type); + table_join->setKind(kind_and_strictness.first); + table_join->setStrictness(kind_and_strictness.second); + return table_join; +} + +CrossRelParser::CrossRelParser(SerializedPlanParser * plan_paser_) + : RelParser(plan_paser_) + , function_mapping(plan_paser_->function_mapping) + , context(plan_paser_->context) + , extra_plan_holder(plan_paser_->extra_plan_holder) +{ +} + +DB::QueryPlanPtr +CrossRelParser::parse(DB::QueryPlanPtr /*query_plan*/, const substrait::Rel & /*rel*/, std::list & /*rel_stack_*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call parse()."); +} + +const substrait::Rel & CrossRelParser::getSingleInput(const substrait::Rel & /*rel*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "join 
node has 2 inputs, can't call getSingleInput()."); +} + +DB::QueryPlanPtr CrossRelParser::parseOp(const substrait::Rel & rel, std::list & rel_stack) +{ + const auto & join = rel.cross(); + if (!join.has_left() || !join.has_right()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "left table or right table is missing."); + } + + rel_stack.push_back(&rel); + auto left_plan = getPlanParser()->parseOp(join.left(), rel_stack); + auto right_plan = getPlanParser()->parseOp(join.right(), rel_stack); + rel_stack.pop_back(); + + return parseJoin(join, std::move(left_plan), std::move(right_plan)); +} + +std::unordered_set CrossRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) +{ + std::unordered_set table_sides; + if (expr.has_scalar_function()) + { + for (const auto & arg : expr.scalar_function().arguments()) + { + auto table_sides_from_arg = extractTableSidesFromExpression(arg.value(), left_header, right_header); + table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end()); + } + } + else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field()) + { + auto pos = expr.selection().direct_reference().struct_field().field(); + if (pos < left_header.columns()) + { + table_sides.insert(DB::JoinTableSide::Left); + } + else + { + table_sides.insert(DB::JoinTableSide::Right); + } + } + else if (expr.has_singular_or_list()) + { + auto child_table_sides = extractTableSidesFromExpression(expr.singular_or_list().value(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + for (const auto & option : expr.singular_or_list().options()) + { + child_table_sides = extractTableSidesFromExpression(option, left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + } + } + else if (expr.has_cast()) + { + auto 
child_table_sides = extractTableSidesFromExpression(expr.cast().input(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + } + else if (expr.has_if_then()) + { + for (const auto & if_child : expr.if_then().ifs()) + { + auto child_table_sides = extractTableSidesFromExpression(if_child.if_(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + child_table_sides = extractTableSidesFromExpression(if_child.then(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + } + auto child_table_sides = extractTableSidesFromExpression(expr.if_then().else_(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + } + else if (expr.has_literal()) + { + // nothing + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression '{}'", expr.DebugString()); + } + return table_sides; +} + + +void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join) +{ + /// To support mixed join conditions, we must make sure that the column names in the right be the same as + /// storage_join's right sample block. + ActionsDAGPtr project = ActionsDAG::makeConvertingActions( + right.getCurrentDataStream().header.getColumnsWithTypeAndName(), + storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), + ActionsDAG::MatchColumnsMode::Position); + + if (project) + { + QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), project); + project_step->setStepDescription("Rename Broadcast Table Name"); + steps.emplace_back(project_step.get()); + right.addStep(std::move(project_step)); + } + + /// If the columns name in right table is duplicated with left table, we need to rename the left table's columns, + /// avoid the columns name in the right table be changed in `addConvertStep`. 
+ /// This could happen in tpc-ds q44. + DB::ColumnsWithTypeAndName new_left_cols; + const auto & right_header = right.getCurrentDataStream().header; + auto left_prefix = getUniqueName("left"); + for (const auto & col : left.getCurrentDataStream().header) + { + if (right_header.has(col.name)) + { + new_left_cols.emplace_back(col.column, col.type, left_prefix + col.name); + } + else + { + new_left_cols.emplace_back(col.column, col.type, col.name); + } + } + project = ActionsDAG::makeConvertingActions( + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), + new_left_cols, + ActionsDAG::MatchColumnsMode::Position); + + if (project) + { + QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project); + project_step->setStepDescription("Rename Left Table Name for broadcast join"); + steps.emplace_back(project_step.get()); + left.addStep(std::move(project_step)); + } +} + +DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) +{ + auto storage_join_key = parseJoinOptimizationInfos(join); + auto storage_join = BroadCastJoinBuilder::getJoin(storage_join_key) ; + renamePlanColumns(*left, *right, *storage_join); + auto table_join = createDefaultTableJoin2(join.type()); + DB::Block right_header_before_convert_step = right->getCurrentDataStream().header; + addConvertStep(*table_join, *left, *right); + + // Add a check to find error easily. 
+ if (storage_join) + { + if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header)) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + left->getCurrentDataStream().header.dumpNames(), + right_header_before_convert_step.dumpNames(), + right->getCurrentDataStream().header.dumpNames()); + } + } + + Names after_join_names; + auto left_names = left->getCurrentDataStream().header.getNames(); + after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end()); + auto right_name = table_join->columnsFromJoinedTable().getNames(); + after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end()); + + auto left_header = left->getCurrentDataStream().header; + auto right_header = right->getCurrentDataStream().header; + + QueryPlanPtr query_plan; + // applyJoinFilter(*table_join, join, *left, *right, true); + table_join->addDisjunct(); + auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); + // table_join->resetKeys(); + QueryPlanStepPtr join_step = std::make_unique(left->getCurrentDataStream(), broadcast_hash_join, 8192); + + join_step->setStepDescription("STORAGE_JOIN"); + steps.emplace_back(join_step.get()); + left->addStep(std::move(join_step)); + query_plan = std::move(left); + /// hold right plan for profile + extra_plan_holder.emplace_back(std::move(right)); + + addPostFilter(*query_plan, join); + reorderJoinOutput2(*query_plan, after_join_names); + + return query_plan; +} + + +void CrossRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join_rel) +{ + if (!join_rel.has_expression()) + return; + + auto expression = join_rel.expression(); + std::string filter_name; + auto actions_dag = std::make_shared(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()); + if 
(!expression.has_scalar_function()) + { + // It may be singular_or_list + auto * in_node = getPlanParser()->parseExpression(actions_dag, expression); + filter_name = in_node->result_name; + } + else + { + getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header, expression, filter_name, actions_dag, true); + } + auto filter_step = std::make_unique(query_plan.getCurrentDataStream(), actions_dag, filter_name, true); + filter_step->setStepDescription("Post Join Filter"); + steps.emplace_back(filter_step.get()); + query_plan.addStep(std::move(filter_step)); +} + +bool CrossRelParser::applyJoinFilter( + DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) +{ + if (!join_rel.has_expression()) + return true; + const auto & expr = join_rel.expression(); + + const auto & left_header = left.getCurrentDataStream().header; + const auto & right_header = right.getCurrentDataStream().header; + ColumnsWithTypeAndName mixed_columns; + std::unordered_set added_column_name; + for (const auto & col : left_header.getColumnsWithTypeAndName()) + { + mixed_columns.emplace_back(col); + added_column_name.insert(col.name); + } + for (const auto & col : right_header.getColumnsWithTypeAndName()) + { + const auto & renamed_col_name = table_join.renamedRightColumnNameWithAlias(col.name); + if (added_column_name.find(col.name) != added_column_name.end()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's name conflict with left column: {}", col.name); + mixed_columns.emplace_back(col); + added_column_name.insert(col.name); + } + DB::Block mixed_header(mixed_columns); + + auto table_sides = extractTableSidesFromExpression(expr, left_header, right_header); + + auto get_input_expressions = [](const DB::Block & header) + { + std::vector exprs; + for (size_t i = 0; i < header.columns(); ++i) + { + substrait::Expression expr; + 
expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i); + exprs.emplace_back(expr); + } + return exprs; + }; + + /// If the columns in the expression are all from one table, use analyzer_left_filter_condition_column_name + /// and analyzer_left_filter_condition_column_name to filt the join result data. It requires to build the filter + /// column at first. + /// If the columns in the expression are from both tables, use mixed_join_expression to filt the join result data. + /// the filter columns will be built inner the join step. + if (table_sides.size() == 1) + { + auto table_side = *table_sides.begin(); + if (table_side == DB::JoinTableSide::Left) + { + auto input_exprs = get_input_expressions(left_header); + input_exprs.push_back(expr); + auto actions_dag = expressionsToActionsDAG(input_exprs, left_header); + table_join.getClauses().back().analyzer_left_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; + QueryPlanStepPtr before_join_step = std::make_unique(left.getCurrentDataStream(), actions_dag); + before_join_step->setStepDescription("Before JOIN LEFT"); + steps.emplace_back(before_join_step.get()); + left.addStep(std::move(before_join_step)); + } + else + { + /// since the field reference in expr is the index of left_header ++ right_header, so we use + /// mixed_header to build the actions_dag + auto input_exprs = get_input_expressions(mixed_header); + input_exprs.push_back(expr); + auto actions_dag = expressionsToActionsDAG(input_exprs, mixed_header); + + /// clear unused columns in actions_dag + for (const auto & col : left_header.getColumnsWithTypeAndName()) + { + actions_dag->removeUnusedResult(col.name); + } + actions_dag->removeUnusedActions(); + + table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; + QueryPlanStepPtr before_join_step = std::make_unique(right.getCurrentDataStream(), actions_dag); + 
before_join_step->setStepDescription("Before JOIN RIGHT"); + steps.emplace_back(before_join_step.get()); + right.addStep(std::move(before_join_step)); + } + } + else if (table_sides.size() == 2) + { + if (!allow_mixed_condition) + return false; + auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); + table_join.getMixedJoinExpression() + = std::make_shared(mixed_join_expressions_actions, ExpressionActionsSettings::fromContext(context)); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); + } + return true; +} + +void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right) +{ + /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns. + NameSet left_columns_set; + for (const auto & col : left.getCurrentDataStream().header.getNames()) + left_columns_set.emplace(col); + table_join.setColumnsFromJoinedTable( + right.getCurrentDataStream().header.getNamesAndTypesList(), left_columns_set, getUniqueName("right") + "."); + + // fix right table key duplicate + NamesWithAliases right_table_alias; + for (size_t idx = 0; idx < table_join.columnsFromJoinedTable().size(); idx++) + { + auto origin_name = right.getCurrentDataStream().header.getByPosition(idx).name; + auto dedup_name = table_join.columnsFromJoinedTable().getNames().at(idx); + if (origin_name != dedup_name) + { + right_table_alias.emplace_back(NameWithAlias(origin_name, dedup_name)); + } + } + if (!right_table_alias.empty()) + { + ActionsDAGPtr rename_dag = std::make_shared(right.getCurrentDataStream().header.getNamesAndTypesList()); + auto original_right_columns = right.getCurrentDataStream().header; + for (const auto & column_alias : right_table_alias) + { + if (original_right_columns.has(column_alias.first)) + { + auto pos = 
original_right_columns.getPositionByName(column_alias.first); + const auto & alias = rename_dag->addAlias(*rename_dag->getInputs()[pos], column_alias.second); + rename_dag->getOutputs()[pos] = &alias; + } + } + + QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), rename_dag); + project_step->setStepDescription("Right Table Rename"); + steps.emplace_back(project_step.get()); + right.addStep(std::move(project_step)); + } + + for (const auto & column : table_join.columnsFromJoinedTable()) + { + table_join.addJoinedColumn(column); + } + ActionsDAGPtr left_convert_actions = nullptr; + ActionsDAGPtr right_convert_actions = nullptr; + std::tie(left_convert_actions, right_convert_actions) = table_join.createConvertingActions( + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), right.getCurrentDataStream().header.getColumnsWithTypeAndName()); + + if (right_convert_actions) + { + auto converting_step = std::make_unique(right.getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + steps.emplace_back(converting_step.get()); + right.addStep(std::move(converting_step)); + } + + if (left_convert_actions) + { + auto converting_step = std::make_unique(left.getCurrentDataStream(), left_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + steps.emplace_back(converting_step.get()); + left.addStep(std::move(converting_step)); + } +} + +/// Join keys are collected from substrait::JoinRel::expression() which only contains the equal join conditions. 
+void CrossRelParser::collectJoinKeys( + TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) +{ + if (!join_rel.has_expression()) + return; + const auto & expr = join_rel.expression(); + auto & join_clause = table_join.getClauses().back(); + std::list expressions_stack; + expressions_stack.push_back(&expr); + while (!expressions_stack.empty()) + { + /// Must handle the expressions in DF order. It matters in sort merge join. + const auto * current_expr = expressions_stack.back(); + expressions_stack.pop_back(); + if (!current_expr->has_scalar_function()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function expression is expected"); + auto function_name = parseFunctionName(current_expr->scalar_function().function_reference(), current_expr->scalar_function()); + if (!function_name) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid function expression"); + if (*function_name == "equals") + { + String left_key, right_key; + size_t left_pos = 0, right_pos = 0; + for (const auto & arg : current_expr->scalar_function().arguments()) + { + if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference() + || !arg.value().selection().direct_reference().has_struct_field()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected"); + } + auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field(); + if (col_pos_ref < left_header.columns()) + { + left_pos = col_pos_ref; + left_key = left_header.getByPosition(col_pos_ref).name; + } + else + { + right_pos = col_pos_ref - left_header.columns(); + right_key = right_header.getByPosition(col_pos_ref - left_header.columns()).name; + } + } + if (left_key.empty() || right_key.empty()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid key equal join condition"); + join_clause.addKey(left_key, right_key, false); + } + else if (*function_name == "and") + { + 
expressions_stack.push_back(&current_expr->scalar_function().arguments().at(1).value()); + expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value()); + } + else if (*function_name == "not") + { + expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value()); + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow function: {}", *function_name); + } + } +} + +void registerCrossRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kCross, builder); +} + +} diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.h b/cpp-ch/local-engine/Parser/CrossRelParser.h new file mode 100644 index 000000000000..9766b4e91d24 --- /dev/null +++ b/cpp-ch/local-engine/Parser/CrossRelParser.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include +#include + +namespace DB +{ +class TableJoin; +} + +namespace local_engine +{ + +class StorageJoinFromReadBuffer; + +std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + +class CrossRelParser : public RelParser +{ +public: + explicit CrossRelParser(SerializedPlanParser * plan_paser_); + ~CrossRelParser() override = default; + + DB::QueryPlanPtr + parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list & rel_stack_) override; + + DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list & rel_stack) override; + + const substrait::Rel & getSingleInput(const substrait::Rel & rel) override; + +private: + std::unordered_map & function_mapping; + ContextPtr & context; + std::vector & extra_plan_holder; + + + DB::QueryPlanPtr parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right); + void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join); + void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right); + void addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join); + void collectJoinKeys( + TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header); + bool applyJoinFilter( + DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition); + static std::unordered_set extractTableSidesFromExpression( + const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); +}; + +} diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index 282339c4d641..f651146a391d 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -156,6 +156,7 @@ void 
registerAggregateParser(RelParserFactory & factory); void registerProjectRelParser(RelParserFactory & factory); void registerJoinRelParser(RelParserFactory & factory); void registerFilterRelParser(RelParserFactory & factory); +void registerCrossRelParser(RelParserFactory & factory); void registerRelParsers() { @@ -166,6 +167,7 @@ void registerRelParsers() registerAggregateParser(factory); registerProjectRelParser(factory); registerJoinRelParser(factory); + registerCrossRelParser(factory); registerFilterRelParser(factory); } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8c60c6e500a9..1174faf1fe58 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -528,6 +528,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list case substrait::Rel::RelTypeCase::kSort: case substrait::Rel::RelTypeCase::kWindow: case substrait::Rel::RelTypeCase::kJoin: + case substrait::Rel::RelTypeCase::kCross: case substrait::Rel::RelTypeCase::kExpand: { auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this); query_plan = op_parser->parseOp(rel, rel_stack); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 477fdb1f6d44..6364faae1513 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -255,6 +255,7 @@ class SerializedPlanParser friend class FunctionExecutor; friend class NonNullableColumnsResolver; friend class JoinRelParser; + friend class CrossRelParser; friend class MergeTreeRelParser; friend class ProjectRelParser; diff --git a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto index 0e51baf5ad4c..3813de868445 100644 --- 
a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto +++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto @@ -259,6 +259,7 @@ message CrossRel { JOIN_TYPE_OUTER = 2; JOIN_TYPE_LEFT = 3; JOIN_TYPE_RIGHT = 4; + JOIN_TYPE_LEFT_SEMI = 5; } substrait.extensions.AdvancedExtension advanced_extension = 10; diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index d159486373ac..8ddcc7b7f93e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -144,7 +144,7 @@ trait BackendSettingsApi { def supportCartesianProductExec(): Boolean = false - def supportBroadcastNestedLoopJoinExec(): Boolean = false + def supportBroadcastNestedLoopJoinExec(): Boolean = true def supportSampleExec(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 092612ea7340..241682ad867a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -19,17 +19,18 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater -import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.{JoinParams, SubstraitContext} import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} 
-import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec import org.apache.spark.sql.execution.metric.SQLMetric +import com.google.protobuf.Any import io.substrait.proto.CrossRel abstract class BroadcastNestedLoopJoinExecTransformer( @@ -53,7 +54,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( // Hint substrait to switch the left and right, // since we assume always build right side in substrait. - private lazy val needSwitchChildren: Boolean = buildSide match { + protected lazy val needSwitchChildren: Boolean = buildSide match { case BuildLeft => true case BuildRight => false } @@ -79,6 +80,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case LeftSemi => // LeftExistence(_) + left.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) case x => throw new IllegalArgumentException(s"${getClass.getSimpleName} not take $x as the JoinType") } @@ -103,6 +108,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } } + def genJoinParameters(): Any = Any.getDefaultInstance + override protected def doTransform(context: SubstraitContext): TransformContext = { val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context) val (inputStreamedRelNode, inputStreamedOutput) = @@ -113,6 +120,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer( (buildPlanContext.root, buildPlanContext.outputAttributes) val operatorId = context.nextOperatorId(this.nodeName) + val joinParams = new JoinParams + if (condition.isDefined) { + 
joinParams.isWithCondition = true + } val crossRel = JoinUtils.createCrossRel( substraitJoinType, @@ -122,9 +133,12 @@ abstract class BroadcastNestedLoopJoinExecTransformer( inputStreamedOutput, inputBuildOutput, context, - operatorId + operatorId, + genJoinParameters() ) + context.registerJoinParam(operatorId, joinParams) + val projectRelPostJoinRel = JoinUtils.createProjectRelPostJoinRel( needSwitchChildren, joinType, @@ -168,6 +182,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( buildPlan.output, substraitContext, substraitContext.nextOperatorId(this.nodeName), + genJoinParameters(), validation = true ) doNativeValidation(substraitContext, crossRel) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala index eb2c0bfd7229..1371c10dafcc 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala @@ -133,10 +133,15 @@ object JoinUtils { def createJoinExtensionNode( joinParameters: Any, - output: Seq[Attribute]): AdvancedExtensionNode = { + output: Seq[Attribute], + validation: Boolean = false): AdvancedExtensionNode = { // Use field [optimization] in a extension node // to send some join parameters through Substrait plan. 
- val enhancement = createEnhancement(output) + val enhancement = if (validation) { + createEnhancement(output) + } else { + null + } ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement) } @@ -337,6 +342,7 @@ object JoinUtils { inputBuildOutput: Seq[Attribute], substraitContext: SubstraitContext, operatorId: java.lang.Long, + joinParameters: Any, validation: Boolean = false ): RelNode = { val expressionNode = condition.map { @@ -346,7 +352,7 @@ object JoinUtils { .doTransform(substraitContext.registeredFunction) } val extensionNode = - JoinUtils.createExtensionNode(inputStreamedOutput ++ inputBuildOutput, validation) + createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput, validation) RelBuilder.makeCrossRel( inputStreamedRelNode, diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala index 15fc8bea7054..b7a30f7e177a 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala @@ -20,10 +20,14 @@ import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, Trans import org.apache.gluten.utils.{LogLevelUtil, PlanUtil} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSemi} import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeLike} +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec 
import org.apache.spark.sql.internal.SQLConf object MiscColumnarRules { @@ -190,4 +194,22 @@ object MiscColumnarRules { child } } + + // Remove unnecessary bnlj like sql: + // ``` select l.* from l left semi join r; ``` + // The result always is left table. + case class RemoveBroadcastNestedLoopJoin() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case BroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) if condition.isEmpty && joinType == LeftSemi => + buildSide match { + case BuildLeft => right + case BuildRight => left + } + } + } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index 03b2b66b09b3..801a7d22c49f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.heuristic import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveBroadcastNestedLoopJoin, RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} import org.apache.gluten.extension.columnar.util.AdaptiveContext @@ -108,6 +108,7 
@@ class HeuristicApplier(session: SparkSession) (_: SparkSession) => FallbackEmptySchemaRelation(), (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), (_: SparkSession) => RewriteSparkPlanRulesManager(), + (_: SparkSession) => RemoveBroadcastNestedLoopJoin(), (_: SparkSession) => AddFallbackTagRule() ) ::: List((_: SparkSession) => TransformPreOverrides()) ::: From bdda93d64a1b507bce276f9ee2a6c5a9488a80c8 Mon Sep 17 00:00:00 2001 From: loneylee Date: Tue, 2 Jul 2024 14:26:44 +0800 Subject: [PATCH 2/6] fix metric fix ci fix velox ci tag fix rebase all fallback fix style --- .../gluten/vectorized/StorageJoinBuilder.java | 10 +- .../backendsapi/clickhouse/CHBackend.scala | 6 + .../backendsapi/clickhouse/CHMetricsApi.scala | 4 - ...oadcastNestedLoopJoinExecTransformer.scala | 72 +---- .../FallbackBroadcaseHashJoinRules.scala | 148 +++++---- ...roadcastNestedLoopJoinMetricsUpdater.scala | 20 +- .../GlutenClickHouseTPCDSAbstractSuite.scala | 8 +- ...enClickHouseTPCHSaltNullParquetSuite.scala | 10 +- .../backendsapi/velox/VeloxBackend.scala | 2 - cpp-ch/local-engine/Common/CHUtil.cpp | 51 ++- cpp-ch/local-engine/Common/CHUtil.h | 11 + .../Join/BroadCastJoinBuilder.cpp | 8 +- .../local-engine/Join/BroadCastJoinBuilder.h | 2 +- cpp-ch/local-engine/Parser/CrossRelParser.cpp | 304 ++---------------- cpp-ch/local-engine/Parser/CrossRelParser.h | 6 - cpp-ch/local-engine/Parser/JoinRelParser.cpp | 42 +-- cpp-ch/local-engine/Parser/JoinRelParser.h | 2 - cpp-ch/local-engine/local_engine_jni.cpp | 3 +- .../backendsapi/BackendSettingsApi.scala | 5 +- ...oadcastNestedLoopJoinExecTransformer.scala | 56 +++- .../columnar/heuristic/HeuristicApplier.scala | 3 +- .../columnar/validator/Validators.scala | 2 - .../apache/gluten/utils/SubstraitUtil.scala | 4 + .../ColumnarBroadcastExchangeExec.scala | 9 +- 24 files changed, 273 insertions(+), 515 deletions(-) diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java 
b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java index 9cb49b6a2d30..27725998feeb 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java @@ -74,12 +74,20 @@ public static long build( return converter.genColumnNameWithExprId(attr); }) .collect(Collectors.joining(",")); + + int joinType; + if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) { + joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal(); + } else { + joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(); + } + return nativeBuild( broadCastContext.buildHashTableId(), batches, rowCount, joinKey, - SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(), + joinType, broadCastContext.hasMixedFiltCondition(), toNameStruct(output).toByteArray()); } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index d369b8c1626f..b0c077dcb9db 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -297,4 +298,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } 
override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true + + override def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { + case _: InnerLike | LeftSemi | FullOuter => true + case _ => false + } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index 59c38b02194a..5465e9b60b67 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -357,10 +357,6 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { "extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"), "inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"), "outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"), - "streamPreProjectionTime" -> - SQLMetrics.createTimingMetric(sparkContext, "time of stream side preProjection"), - "buildPreProjectionTime" -> - SQLMetrics.createTimingMetric(sparkContext, "time of build side preProjection"), "postProjectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"), "probeTime" -> diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index 55f86bb32cce..c408a223784b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -16,20 +16,17 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.BackendsApiManager +import 
org.apache.gluten.extension.ValidationResult import org.apache.spark.rdd.RDD import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashJoin} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch -import com.google.protobuf.{Any, StringValue} - case class CHBroadcastNestedLoopJoinExecTransformer( left: SparkPlan, right: SparkPlan, @@ -43,32 +40,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer( joinType, condition ) { - // Unique ID for builded table - lazy val buildBroadcastTableId: String = "BuiltBroadcastTable-" + buildPlan.id - - lazy val (buildKeyExprs, streamedKeyExprs) = { - require( - leftKeys.length == rightKeys.length && - leftKeys - .map(_.dataType) - .zip(rightKeys.map(_.dataType)) - .forall(types => sameType(types._1, types._2)), - "Join keys from two sides should have same length and types" - ) - // Spark has an improvement which would patch integer joins keys to a Long value. - // But this improvement would cause add extra project before hash join in velox, - // disabling this improvement as below would help reduce the project. 
- val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) { - (HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys)) - } else { - (leftKeys, rightKeys) - } - if (needSwitchChildren) { - (lkeys, rkeys) - } else { - (rkeys, lkeys) - } - } override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { val streamedRDD = getColumnarInputRDDs(streamedPlan) @@ -106,38 +77,17 @@ case class CHBroadcastNestedLoopJoinExecTransformer( res } - def sameType(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, _), ArrayType(toElement, _)) => - sameType(fromElement, toElement) - - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => - sameType(fromKey, toKey) && - sameType(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { - case (l, r) => - l.name.equalsIgnoreCase(r.name) && - sameType(l.dataType, r.dataType) + override def validateJoinTypeAndBuildSide(): ValidationResult = { + joinType match { + case _: InnerLike => + case _ => + if (condition.isDefined) { + return ValidationResult.notOk( + s"Broadcast Nested Loop join is not supported join type $joinType with conditions") } - - case (fromDataType, toDataType) => fromDataType == toDataType } - } - override def genJoinParameters(): Any = { - val joinParametersStr = new StringBuffer("JoinParameters:") - joinParametersStr - .append("buildHashTableId=") - .append(buildBroadcastTableId) - .append("\n") - val message = StringValue - .newBuilder() - .setValue(joinParametersStr.toString) - .build() - BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + ValidationResult.ok } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala index 59c2d6494bdb..f7d9a6dbe7ea 
100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec} -import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} import scala.util.control.Breaks.{break, breakable} @@ -103,6 +103,10 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl GlutenConfig.getConf.enableColumnarBroadcastJoin && GlutenConfig.getConf.enableColumnarBroadcastExchange + private val enableColumnarBroadcastNestedLoopJoin: Boolean = + GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled && + GlutenConfig.getConf.enableColumnarBroadcastExchange + override def apply(plan: SparkPlan): SparkPlan = { plan.foreachUp { p => @@ -138,63 +142,9 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl case BuildRight => bhj.right } - val maybeExchange = buildSidePlan - .find { - case BroadcastExchangeExec(_, _) => true - case _ => false - } - .map(_.asInstanceOf[BroadcastExchangeExec]) - - maybeExchange match { - case Some(exchange @ BroadcastExchangeExec(mode, child)) => - isBhjTransformable.tagOnFallback(bhj) - if (!isBhjTransformable.isValid) { - FallbackTags.add(exchange, isBhjTransformable) - } - case None => - // we are in AQE, find the hidden exchange - // FIXME did we consider the case that AQE: OFF && Reuse: ON ? 
- var maybeHiddenExchange: Option[BroadcastExchangeLike] = None - breakable { - buildSidePlan.foreach { - case e: BroadcastExchangeLike => - maybeHiddenExchange = Some(e) - break - case t: BroadcastQueryStageExec => - t.plan.foreach { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case r: ReusedExchangeExec => - r.child match { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case _ => - } - case _ => - } - case _ => - } - } - // restriction to force the hidden exchange to be found - val exchange = maybeHiddenExchange.get - // to conform to the underlying exchange's type, columnar or vanilla - exchange match { - case BroadcastExchangeExec(mode, child) => - FallbackTags.add( - bhj, - "it's a materialized broadcast exchange or reused broadcast exchange") - case ColumnarBroadcastExchangeExec(mode, child) => - if (!isBhjTransformable.isValid) { - throw new IllegalStateException( - s"BroadcastExchange has already been" + - s" transformed to columnar version but BHJ is determined as" + - s" non-transformable: ${bhj.toString()}") - } - } - } + preTagBroadcastExchangeFallback(bhj, buildSidePlan, isBhjTransformable) } + case bnlj: BroadcastNestedLoopJoinExec => applyBNLJFallback(bnlj) case _ => } } catch { @@ -207,4 +157,88 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl } plan } + + private def applyBNLJFallback(bnlj: BroadcastNestedLoopJoinExec) = { + if (!enableColumnarBroadcastNestedLoopJoin) { + FallbackTags.add(bnlj, "columnar BroadcastJoin is not enabled in BroadcastNestedLoopJoinExec") + } + + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastNestedLoopJoinExecTransformer( + bnlj.left, + bnlj.right, + bnlj.buildSide, + bnlj.joinType, + bnlj.condition) + + val isBNLJTransformable = transformer.doValidate() + val buildSidePlan = bnlj.buildSide match { + case BuildLeft => bnlj.left + case BuildRight => bnlj.right + } + + 
preTagBroadcastExchangeFallback(bnlj, buildSidePlan, isBNLJTransformable) + } + + private def preTagBroadcastExchangeFallback( + plan: SparkPlan, + buildSidePlan: SparkPlan, + isTransformable: ValidationResult): Unit = { + val maybeExchange = buildSidePlan + .find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + .map(_.asInstanceOf[BroadcastExchangeExec]) + + maybeExchange match { + case Some(exchange @ BroadcastExchangeExec(_, _)) => + isTransformable.tagOnFallback(plan) + if (!isTransformable.isValid) { + FallbackTags.add(exchange, isTransformable) + } + case None => + // we are in AQE, find the hidden exchange + // FIXME did we consider the case that AQE: OFF && Reuse: ON ? + var maybeHiddenExchange: Option[BroadcastExchangeLike] = None + breakable { + buildSidePlan.foreach { + case e: BroadcastExchangeLike => + maybeHiddenExchange = Some(e) + break + case t: BroadcastQueryStageExec => + t.plan.foreach { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case r: ReusedExchangeExec => + r.child match { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case _ => + } + case _ => + } + case _ => + } + } + // restriction to force the hidden exchange to be found + val exchange = maybeHiddenExchange.get + // to conform to the underlying exchange's type, columnar or vanilla + exchange match { + case BroadcastExchangeExec(mode, child) => + FallbackTags.add( + plan, + "it's a materialized broadcast exchange or reused broadcast exchange") + case ColumnarBroadcastExchangeExec(mode, child) => + if (!isTransformable.isValid) { + throw new IllegalStateException( + s"BroadcastExchange has already been" + + s" transformed to columnar version but BHJ is determined as" + + s" non-transformable: ${plan.toString()}") + } + } + } + } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala 
b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala index 2f8875cf9b74..b1414bf9727c 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala @@ -32,24 +32,6 @@ class BroadcastNestedLoopJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) var currentIdx = operatorMetrics.metricsList.size() - 1 var totalTime = 0L - // build side pre projection - if (joinParams.buildPreProjectionNeeded) { - metrics("buildPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - - // stream side pre projection - if (joinParams.streamPreProjectionNeeded) { - metrics("streamPreProjectionTime") += - (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong - metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors - totalTime += operatorMetrics.metricsList.get(currentIdx).time - currentIdx -= 1 - } - // update fillingRightJoinSideTime MetricsUtil .getAllProcessorList(operatorMetrics.metricsList.get(currentIdx)) @@ -76,6 +58,8 @@ class BroadcastNestedLoopJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) } if (processor.name.equalsIgnoreCase("FilterTransform")) { metrics("conditionTime") += (processor.time / 1000L).toLong + metrics("numOutputRows") += processor.outputRows - processor.inputRows + metrics("outputBytes") += processor.outputBytes - processor.inputBytes } if (processor.name.equalsIgnoreCase("JoiningTransform")) { metrics("probeTime") += (processor.time / 1000L).toLong diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala index ed2e5c10e6c6..4a732785c1b3 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -56,14 +56,11 @@ abstract class GlutenClickHouseTPCDSAbstractSuite Seq("q" + "%d".format(queryNum)) } val noFallBack = queryNum match { - case i - if i == 10 || i == 16 || i == 35 || i == 45 || i == 77 || - i == 94 => + case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 => // Q10 BroadcastHashJoin, ExistenceJoin // Q16 ShuffledHashJoin, NOT condition // Q35 BroadcastHashJoin, ExistenceJoin // Q45 BroadcastHashJoin, ExistenceJoin - // Q77 CartesianProduct // Q94 BroadcastHashJoin, LeftSemi, NOT condition (false, false) case j if j == 38 || j == 87 => @@ -73,6 +70,9 @@ abstract class GlutenClickHouseTPCDSAbstractSuite } else { (false, true) } + case q77 if q77 == 77 && !isAqe => + // Q77 CartesianProduct + (false, false) case other => (true, false) } sqlNums.map((_, noFallBack._1, noFallBack._2)) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index b0d3e1bdb866..4005d9f2e02a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -1743,7 +1743,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql1, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql1, true, { _ => }) val sql2 = """ @@ -1764,7 +1764,7 @@ 
class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql2, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql2, true, { _ => }) val sql3 = """ @@ -1785,7 +1785,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql3, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql3, true, { _ => }) val sql4 = """ @@ -1806,7 +1806,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql4, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql4, true, { _ => }) val sql5 = """ @@ -1827,7 +1827,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql5, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql5, true, { _ => }) } test("GLUTEN-1874 not null in one stream") { diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 9c1089a35bea..e4ed39f46676 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -511,8 +511,6 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportCartesianProductExec(): Boolean = true - override def supportBroadcastNestedLoopJoinExec(): Boolean = true - override def supportSampleExec(): Boolean = true override def supportColumnarArrowUdf(): Boolean = true diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 1889ed519fa0..d417dab53649 100644 --- 
a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -51,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,6 @@ #include #include #include -#include #include #include #include @@ -1075,4 +1075,53 @@ UInt64 MemoryUtil::getMemoryRSS() return rss * sysconf(_SC_PAGESIZE); } + +void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) +{ + ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); + NamesWithAliases project_cols; + for (const auto & col : cols) + { + project_cols.emplace_back(NameWithAlias(col, col)); + } + project->project(project_cols); + QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); + project_step->setStepDescription("Reorder Join Output"); + plan.addStep(std::move(project_step)); +} + +std::pair JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + return {DB::JoinKind::Inner, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; + case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: + return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + return {DB::JoinKind::Left, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: + return {DB::JoinKind::Right, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Full, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + +std::pair JoinUtil::getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: + case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + 
case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Cross, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + } diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index ff25415c0cbb..938ca9d11489 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -16,6 +16,7 @@ * limitations under the License. */ #pragma once + #include #include #include @@ -25,6 +26,8 @@ #include #include #include +#include +#include #include namespace DB @@ -302,4 +305,12 @@ class ConcurrentDeque mutable std::mutex mtx; }; +class JoinUtil +{ +public: + static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); + static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); +}; + } diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 3f3c7e6c32aa..f0c9612dc567 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -97,7 +97,7 @@ std::shared_ptr buildJoin( DB::ReadBuffer & input, jlong row_count, const std::string & join_keys, - substrait::JoinRel_JoinType join_type, + jint join_type, bool has_mixed_join_condition, const std::string & named_struct) { @@ -109,7 +109,11 @@ std::shared_ptr buildJoin( DB::JoinKind kind; DB::JoinStrictness strictness; - std::tie(kind, strictness) = getJoinKindAndStrictness(join_type); + if (key.starts_with("BuiltBNLJBroadcastTable-")) + std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast(join_type)); + else + std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type)); + substrait::NamedStruct substrait_struct; substrait_struct.ParseFromString(named_struct); diff 
--git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h index 9a6837e35a0a..3d2e67f9df10 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h @@ -35,7 +35,7 @@ std::shared_ptr buildJoin( DB::ReadBuffer & input, jlong row_count, const std::string & join_keys, - substrait::JoinRel_JoinType join_type, + jint join_type, bool has_mixed_join_condition, const std::string & named_struct); void cleanBuildHashTable(const std::string & hash_table_id, jlong instance); diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp index 071588a9f26e..5abf1eb408b0 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.cpp +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -15,12 +15,12 @@ * limitations under the License. */ #include "CrossRelParser.h" + #include #include #include -#include #include -#include +#include #include #include #include @@ -30,8 +30,7 @@ #include #include #include - -#include +#include #include @@ -45,24 +44,17 @@ namespace ErrorCodes } } -struct JoinOptimizationInfo -{ - bool is_broadcast = false; - bool is_smj = false; - bool is_null_aware_anti_join = false; - bool is_existence_join = false; - std::string storage_join_key; -}; - using namespace DB; -String parseJoinOptimizationInfos(const substrait::CrossRel & join) + + +namespace local_engine +{ +String parseCrossJoinOptimizationInfos(const substrait::CrossRel & join) { google::protobuf::StringValue optimization; optimization.ParseFromString(join.advanced_extension().optimization().value()); - JoinOptimizationInfo info; - auto a = optimization.value(); String storage_join_key; ReadBufferFromString in(optimization.value()); assertString("JoinParameters:", in); @@ -71,49 +63,13 @@ String parseJoinOptimizationInfos(const substrait::CrossRel & join) return storage_join_key; } -void reorderJoinOutput2(DB::QueryPlan & plan, DB::Names cols) -{ - 
ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); - NamesWithAliases project_cols; - for (const auto & col : cols) - { - project_cols.emplace_back(NameWithAlias(col, col)); - } - project->project(project_cols); - QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); - project_step->setStepDescription("Reorder Join Output"); - plan.addStep(std::move(project_step)); -} - -namespace local_engine -{ - -std::pair getJoinKindAndStrictness2(substrait::CrossRel_JoinType join_type) -{ - switch (join_type) - { - case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: - case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: - case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Cross, DB::JoinStrictness::All}; - // case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: - // return {DB::JoinKind::Left, DB::JoinStrictness::All}; - // - // case substrait::CrossRel_JoinType_JOIN_TYPE_RIGHT: - // return {DB::JoinKind::Right, DB::JoinStrictness::All}; - // case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: - // return {DB::JoinKind::Full, DB::JoinStrictness::All}; - default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); - } -} -std::shared_ptr createDefaultTableJoin2(substrait::CrossRel_JoinType join_type) +std::shared_ptr createCrossTableJoin(substrait::CrossRel_JoinType join_type) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = getJoinKindAndStrictness2(join_type); + std::pair kind_and_strictness = JoinUtil::getCrossJoinKindAndStrictness(join_type); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -154,68 +110,6 @@ DB::QueryPlanPtr 
CrossRelParser::parseOp(const substrait::Rel & rel, std::list CrossRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) -{ - std::unordered_set table_sides; - if (expr.has_scalar_function()) - { - for (const auto & arg : expr.scalar_function().arguments()) - { - auto table_sides_from_arg = extractTableSidesFromExpression(arg.value(), left_header, right_header); - table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end()); - } - } - else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field()) - { - auto pos = expr.selection().direct_reference().struct_field().field(); - if (pos < left_header.columns()) - { - table_sides.insert(DB::JoinTableSide::Left); - } - else - { - table_sides.insert(DB::JoinTableSide::Right); - } - } - else if (expr.has_singular_or_list()) - { - auto child_table_sides = extractTableSidesFromExpression(expr.singular_or_list().value(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - for (const auto & option : expr.singular_or_list().options()) - { - child_table_sides = extractTableSidesFromExpression(option, left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - } - else if (expr.has_cast()) - { - auto child_table_sides = extractTableSidesFromExpression(expr.cast().input(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - else if (expr.has_if_then()) - { - for (const auto & if_child : expr.if_then().ifs()) - { - auto child_table_sides = extractTableSidesFromExpression(if_child.if_(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - child_table_sides = extractTableSidesFromExpression(if_child.then(), left_header, right_header); - 
table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - auto child_table_sides = extractTableSidesFromExpression(expr.if_then().else_(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - else if (expr.has_literal()) - { - // nothing - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression '{}'", expr.DebugString()); - } - return table_sides; -} - - void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join) { /// To support mixed join conditions, we must make sure that the column names in the right be the same as @@ -266,23 +160,20 @@ void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & rig DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) { - auto storage_join_key = parseJoinOptimizationInfos(join); + auto storage_join_key = parseCrossJoinOptimizationInfos(join); auto storage_join = BroadCastJoinBuilder::getJoin(storage_join_key) ; renamePlanColumns(*left, *right, *storage_join); - auto table_join = createDefaultTableJoin2(join.type()); + auto table_join = createCrossTableJoin(join.type()); DB::Block right_header_before_convert_step = right->getCurrentDataStream().header; addConvertStep(*table_join, *left, *right); // Add a check to find error easily. 
- if (storage_join) + if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header)) { - if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header)) - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", - left->getCurrentDataStream().header.dumpNames(), - right_header_before_convert_step.dumpNames(), - right->getCurrentDataStream().header.dumpNames()); - } + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + left->getCurrentDataStream().header.dumpNames(), + right_header_before_convert_step.dumpNames(), + right->getCurrentDataStream().header.dumpNames()); } Names after_join_names; @@ -295,7 +186,6 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: auto right_header = right->getCurrentDataStream().header; QueryPlanPtr query_plan; - // applyJoinFilter(*table_join, join, *left, *right, true); table_join->addDisjunct(); auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); // table_join->resetKeys(); @@ -309,7 +199,7 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: extra_plan_holder.emplace_back(std::move(right)); addPostFilter(*query_plan, join); - reorderJoinOutput2(*query_plan, after_join_names); + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } @@ -339,102 +229,6 @@ void CrossRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait:: query_plan.addStep(std::move(filter_step)); } -bool CrossRelParser::applyJoinFilter( - DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) -{ - if (!join_rel.has_expression()) - return true; - 
const auto & expr = join_rel.expression(); - - const auto & left_header = left.getCurrentDataStream().header; - const auto & right_header = right.getCurrentDataStream().header; - ColumnsWithTypeAndName mixed_columns; - std::unordered_set added_column_name; - for (const auto & col : left_header.getColumnsWithTypeAndName()) - { - mixed_columns.emplace_back(col); - added_column_name.insert(col.name); - } - for (const auto & col : right_header.getColumnsWithTypeAndName()) - { - const auto & renamed_col_name = table_join.renamedRightColumnNameWithAlias(col.name); - if (added_column_name.find(col.name) != added_column_name.end()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's name conflict with left column: {}", col.name); - mixed_columns.emplace_back(col); - added_column_name.insert(col.name); - } - DB::Block mixed_header(mixed_columns); - - auto table_sides = extractTableSidesFromExpression(expr, left_header, right_header); - - auto get_input_expressions = [](const DB::Block & header) - { - std::vector exprs; - for (size_t i = 0; i < header.columns(); ++i) - { - substrait::Expression expr; - expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i); - exprs.emplace_back(expr); - } - return exprs; - }; - - /// If the columns in the expression are all from one table, use analyzer_left_filter_condition_column_name - /// and analyzer_left_filter_condition_column_name to filt the join result data. It requires to build the filter - /// column at first. - /// If the columns in the expression are from both tables, use mixed_join_expression to filt the join result data. - /// the filter columns will be built inner the join step. 
- if (table_sides.size() == 1) - { - auto table_side = *table_sides.begin(); - if (table_side == DB::JoinTableSide::Left) - { - auto input_exprs = get_input_expressions(left_header); - input_exprs.push_back(expr); - auto actions_dag = expressionsToActionsDAG(input_exprs, left_header); - table_join.getClauses().back().analyzer_left_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; - QueryPlanStepPtr before_join_step = std::make_unique(left.getCurrentDataStream(), actions_dag); - before_join_step->setStepDescription("Before JOIN LEFT"); - steps.emplace_back(before_join_step.get()); - left.addStep(std::move(before_join_step)); - } - else - { - /// since the field reference in expr is the index of left_header ++ right_header, so we use - /// mixed_header to build the actions_dag - auto input_exprs = get_input_expressions(mixed_header); - input_exprs.push_back(expr); - auto actions_dag = expressionsToActionsDAG(input_exprs, mixed_header); - - /// clear unused columns in actions_dag - for (const auto & col : left_header.getColumnsWithTypeAndName()) - { - actions_dag->removeUnusedResult(col.name); - } - actions_dag->removeUnusedActions(); - - table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; - QueryPlanStepPtr before_join_step = std::make_unique(right.getCurrentDataStream(), actions_dag); - before_join_step->setStepDescription("Before JOIN RIGHT"); - steps.emplace_back(before_join_step.get()); - right.addStep(std::move(before_join_step)); - } - } - else if (table_sides.size() == 2) - { - if (!allow_mixed_condition) - return false; - auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); - table_join.getMixedJoinExpression() - = std::make_shared(mixed_join_expressions_actions, ExpressionActionsSettings::fromContext(context)); - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join 
condition.\n{}", join_rel.DebugString()); - } - return true; -} - void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right) { /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns. @@ -501,68 +295,6 @@ void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left } } -/// Join keys are collected from substrait::JoinRel::expression() which only contains the equal join conditions. -void CrossRelParser::collectJoinKeys( - TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) -{ - if (!join_rel.has_expression()) - return; - const auto & expr = join_rel.expression(); - auto & join_clause = table_join.getClauses().back(); - std::list expressions_stack; - expressions_stack.push_back(&expr); - while (!expressions_stack.empty()) - { - /// Must handle the expressions in DF order. It matters in sort merge join. 
- const auto * current_expr = expressions_stack.back(); - expressions_stack.pop_back(); - if (!current_expr->has_scalar_function()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function expression is expected"); - auto function_name = parseFunctionName(current_expr->scalar_function().function_reference(), current_expr->scalar_function()); - if (!function_name) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid function expression"); - if (*function_name == "equals") - { - String left_key, right_key; - size_t left_pos = 0, right_pos = 0; - for (const auto & arg : current_expr->scalar_function().arguments()) - { - if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference() - || !arg.value().selection().direct_reference().has_struct_field()) - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected"); - } - auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field(); - if (col_pos_ref < left_header.columns()) - { - left_pos = col_pos_ref; - left_key = left_header.getByPosition(col_pos_ref).name; - } - else - { - right_pos = col_pos_ref - left_header.columns(); - right_key = right_header.getByPosition(col_pos_ref - left_header.columns()).name; - } - } - if (left_key.empty() || right_key.empty()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid key equal join condition"); - join_clause.addKey(left_key, right_key, false); - } - else if (*function_name == "and") - { - expressions_stack.push_back(&current_expr->scalar_function().arguments().at(1).value()); - expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value()); - } - else if (*function_name == "not") - { - expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value()); - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow function: {}", *function_name); - } - } -} void registerCrossRelParser(RelParserFactory & factory) { diff --git
a/cpp-ch/local-engine/Parser/CrossRelParser.h b/cpp-ch/local-engine/Parser/CrossRelParser.h index 9766b4e91d24..f1cd60385e26 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.h +++ b/cpp-ch/local-engine/Parser/CrossRelParser.h @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include @@ -31,7 +30,6 @@ namespace local_engine class StorageJoinFromReadBuffer; -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); class CrossRelParser : public RelParser { @@ -56,12 +54,8 @@ class CrossRelParser : public RelParser void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join); void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right); void addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join); - void collectJoinKeys( - TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header); bool applyJoinFilter( DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition); - static std::unordered_set extractTableSidesFromExpression( - const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); }; } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index a6a146954d6f..03734a2a9f0d 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -15,6 +15,7 @@ * limitations under the License. 
*/ #include "JoinRelParser.h" + #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include @@ -98,51 +100,15 @@ JoinOptimizationInfo parseJoinOptimizationInfo(const substrait::JoinRel & join) return info; } - -void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) -{ - ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); - NamesWithAliases project_cols; - for (const auto & col : cols) - { - project_cols.emplace_back(NameWithAlias(col, col)); - } - project->project(project_cols); - QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); - project_step->setStepDescription("Reorder Join Output"); - plan.addStep(std::move(project_step)); -} - namespace local_engine { - -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) -{ - switch (join_type) - { - case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: - return {DB::JoinKind::Inner, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: - return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; - case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: - return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: - return {DB::JoinKind::Left, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: - return {DB::JoinKind::Right, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Full, DB::JoinStrictness::All}; - default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); - } -} std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair 
kind_and_strictness = getJoinKindAndStrictness(join_type); + std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -436,7 +402,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } - reorderJoinOutput(*query_plan, after_join_names); + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index c423f43908e7..15468b54b6f4 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -31,8 +31,6 @@ namespace local_engine class StorageJoinFromReadBuffer; -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); - class JoinRelParser : public RelParser { public: diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 695fc8585538..627e6154cdcf 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -1121,13 +1121,12 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild const auto named_struct_a = local_engine::getByteArrayElementsSafe(env, named_struct); const std::string::size_type struct_size = named_struct_a.length(); std::string struct_string{reinterpret_cast(named_struct_a.elems()), struct_size}; - const auto join_type = static_cast(join_type_); const jsize length = env->GetArrayLength(in); local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, length); DB::CompressedReadBuffer input(read_buffer_from_java_array); local_engine::configureCompressedReadBuffer(input); const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin( - hash_table_id, 
input, row_count_, join_key, join_type, has_mixed_join_condition, struct_string)); + hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, struct_string)); return obj->instance(); LOCAL_ENGINE_JNI_METHOD_END(env, 0) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index 8ddcc7b7f93e..b6bccc480601 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -144,7 +144,10 @@ trait BackendSettingsApi { def supportCartesianProductExec(): Boolean = false - def supportBroadcastNestedLoopJoinExec(): Boolean = true + def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { + case _: InnerLike | LeftOuter | RightOuter => true + case _ => false + } def supportSampleExec(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 241682ad867a..ef25b39ab9bd 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater @@ -24,13 +25,13 @@ import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, 
LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec import org.apache.spark.sql.execution.metric.SQLMetric -import com.google.protobuf.Any +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.CrossRel abstract class BroadcastNestedLoopJoinExecTransformer( @@ -50,11 +51,12 @@ abstract class BroadcastNestedLoopJoinExecTransformer( private lazy val substraitJoinType: CrossRel.JoinType = SubstraitUtil.toCrossRelSubstrait(joinType) - private lazy val buildTableId: String = "BuildTable-" + buildPlan.id + // Unique ID for builded table + lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id // Hint substrait to switch the left and right, // since we assume always build right side in substrait. 
- protected lazy val needSwitchChildren: Boolean = buildSide match { + private lazy val needSwitchChildren: Boolean = buildSide match { case BuildLeft => true case BuildRight => false } @@ -80,7 +82,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case LeftSemi => // LeftExistence(_) + case LeftExistence(_) => left.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) @@ -108,7 +110,19 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } } - def genJoinParameters(): Any = Any.getDefaultInstance + def genJoinParameters(): Any = { + // for ch + val joinParametersStr = new StringBuffer("JoinParameters:") + joinParametersStr + .append("buildHashTableId=") + .append(buildBroadcastTableId) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } override protected def doTransform(context: SubstraitContext): TransformContext = { val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context) @@ -159,18 +173,36 @@ abstract class BroadcastNestedLoopJoinExecTransformer( inputBuildOutput) } + def validateJoinTypeAndBuildSide(): ValidationResult = { + joinType match { + case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok + case _ => + ValidationResult.notOk( + s"Broadcast Nested Loop join is not supported join type $joinType in this backend") + } + + (joinType, buildSide) match { + case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => + ValidationResult.notOk(s"$joinType join is not supported with $buildSide") + case _ => ValidationResult.ok // continue + } + } + override protected def doValidateInternal(): ValidationResult = { - if (!BackendsApiManager.getSettings.supportBroadcastNestedLoopJoinExec()) 
{ - return ValidationResult.notOk("Broadcast Nested Loop join is not supported in this backend") + if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) { + return ValidationResult.notOk( + s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled") } + if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") } - (joinType, buildSide) match { - case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => - return ValidationResult.notOk(s"$joinType join is not supported with $buildSide") - case _ => // continue + + val validateResult = validateJoinTypeAndBuildSide() + if (!validateResult.isValid) { + return validateResult } + val substraitContext = new SubstraitContext val crossRel = JoinUtils.createCrossRel( diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala index 801a7d22c49f..03b2b66b09b3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicApplier.scala @@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.heuristic import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.columnar._ -import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveBroadcastNestedLoopJoin, RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} +import org.apache.gluten.extension.columnar.MiscColumnarRules.{RemoveGlutenTableCacheColumnarToRow, RemoveTopmostColumnarToRow, RewriteSubqueryBroadcast, TransformPreOverrides} import 
org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions} import org.apache.gluten.extension.columnar.util.AdaptiveContext @@ -108,7 +108,6 @@ class HeuristicApplier(session: SparkSession) (_: SparkSession) => FallbackEmptySchemaRelation(), (spark: SparkSession) => MergeTwoPhasesHashBaseAggregate(spark), (_: SparkSession) => RewriteSparkPlanRulesManager(), - (_: SparkSession) => RemoveBroadcastNestedLoopJoin(), (_: SparkSession) => AddFallbackTagRule() ) ::: List((_: SparkSession) => TransformPreOverrides()) ::: diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala index 959bf808aba4..b6236ae9a536 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala @@ -147,8 +147,6 @@ object Validators { case p: SortAggregateExec if !settings.replaceSortAggWithHashAgg => fail(p) case p: CartesianProductExec if !settings.supportCartesianProductExec() => fail(p) - case p: BroadcastNestedLoopJoinExec if !settings.supportBroadcastNestedLoopJoinExec() => - fail(p) case p: TakeOrderedAndProjectExec if !settings.supportColumnarShuffleExec() => fail(p) case _ => pass() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala index 9671c7a6bca2..e8e7ce06feaf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala @@ -48,6 +48,10 @@ object SubstraitUtil { // the left and right relations are exchanged and the // join type is reverted. 
CrossRel.JoinType.JOIN_TYPE_LEFT + case LeftSemi => + CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI + case FullOuter => + CrossRel.JoinType.JOIN_TYPE_OUTER case _ => CrossRel.JoinType.UNRECOGNIZED } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index df1c87cb0ccc..4da7a2f6f11a 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -134,13 +134,6 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } override protected def doValidateInternal(): ValidationResult = { - // CH backend does not support IdentityBroadcastMode used in BNLJ - if ( - mode == IdentityBroadcastMode && !BackendsApiManager.getSettings - .supportBroadcastNestedLoopJoinExec() - ) { - return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ") - } BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map { From 2ac3fb90e4f2d7827604d18b423562cfb58cb020 Mon Sep 17 00:00:00 2001 From: loneylee Date: Mon, 8 Jul 2024 19:39:46 +0800 Subject: [PATCH 3/6] fix ci fix ci fix ci 2 
fix velox error --- .../backendsapi/clickhouse/CHBackend.scala | 5 -- ...oadcastNestedLoopJoinExecTransformer.scala | 22 ++++++- .../FallbackBroadcaseHashJoinRules.scala | 48 +++++++++++++++ .../backendsapi/velox/VeloxBackend.scala | 2 + cpp-ch/local-engine/Common/CHUtil.cpp | 2 - cpp-ch/local-engine/Common/CHUtil.h | 1 + .../Join/BroadCastJoinBuilder.cpp | 8 ++- cpp-ch/local-engine/Parser/CrossRelParser.cpp | 59 ++++++++++--------- .../backendsapi/BackendSettingsApi.scala | 5 +- ...oadcastNestedLoopJoinExecTransformer.scala | 26 +++----- .../apache/gluten/execution/JoinUtils.scala | 11 +--- 11 files changed, 122 insertions(+), 67 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index b0c077dcb9db..4c9edd57c930 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -299,8 +298,4 @@ object CHBackendSettings extends BackendSettingsApi with Logging { override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true - override def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { - case _: InnerLike | LeftSemi | FullOuter => true - case _ => false - } } diff --git 
a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index c408a223784b..35be8ee0b13e 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -16,17 +16,20 @@ */ package org.apache.gluten.execution +import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.spark.rdd.RDD import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType} +import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import com.google.protobuf.{Any, StringValue} + case class CHBroadcastNestedLoopJoinExecTransformer( left: SparkPlan, right: SparkPlan, @@ -56,7 +59,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer( val context = BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId) val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context) - // FIXME: Do we have to make build side a RDD? 
streamedRDD :+ broadcastRDD } @@ -77,11 +79,25 @@ case class CHBroadcastNestedLoopJoinExecTransformer( res } + override def genJoinParameters(): Any = { + // for ch + val joinParametersStr = new StringBuffer("JoinParameters:") + joinParametersStr + .append("buildHashTableId=") + .append(buildBroadcastTableId) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } + override def validateJoinTypeAndBuildSide(): ValidationResult = { joinType match { case _: InnerLike => case _ => - if (condition.isDefined) { + if (joinType == LeftSemi || condition.isDefined) { return ValidationResult.notOk( s"Broadcast Nested Loop join is not supported join type $joinType with conditions") } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala index f7d9a6dbe7ea..c7f9b47de642 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala @@ -89,10 +89,58 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend // Skip. 
This might be the case that the exchange was already // executed in earlier stage } + case bnlj: BroadcastNestedLoopJoinExec => applyBNLJPrepQueryStage(bnlj) case _ => } plan } + + private def applyBNLJPrepQueryStage(bnlj: BroadcastNestedLoopJoinExec) = { + val buildSidePlan = bnlj.buildSide match { + case BuildLeft => bnlj.left + case BuildRight => bnlj.right + } + val maybeExchange = buildSidePlan.find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + maybeExchange match { + case Some(exchange @ BroadcastExchangeExec(mode, child)) => + val isTransformable = + if ( + !GlutenConfig.getConf.enableColumnarBroadcastExchange || + !GlutenConfig.getConf.enableColumnarBroadcastJoin + ) { + ValidationResult.notOk( + "columnar broadcast exchange is disabled or " + + "columnar broadcast join is disabled") + } else { + if (FallbackTags.nonEmpty(bnlj)) { + ValidationResult.notOk("broadcast join is already tagged as not transformable") + } else { + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastNestedLoopJoinExecTransformer( + bnlj.left, + bnlj.right, + bnlj.buildSide, + bnlj.joinType, + bnlj.condition) + val isTransformable = transformer.doValidate() + if (isTransformable.isValid) { + val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child) + exchangeTransformer.doValidate() + } else { + isTransformable + } + } + } + FallbackTags.add(bnlj, isTransformable) + FallbackTags.add(exchange, isTransformable) + case _ => + // Skip. 
This might be the case that the exchange was already + // executed in earlier stage + } + } } // For similar purpose with FallbackBroadcastHashJoinPrepQueryStage, executed during applying diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index e4ed39f46676..9c1089a35bea 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -511,6 +511,8 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportCartesianProductExec(): Boolean = true + override def supportBroadcastNestedLoopJoinExec(): Boolean = true + override def supportSampleExec(): Boolean = true override def supportColumnarArrowUdf(): Boolean = true diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index d417dab53649..850c863d0bfd 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -84,8 +84,6 @@ extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; namespace local_engine { -constexpr auto VIRTUAL_ROW_COUNT_COLUMN = "__VIRTUAL_ROW_COUNT_COLUMN__"; - namespace fs = std::filesystem; DB::Block BlockUtil::buildRowCountHeader() diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 938ca9d11489..65764af7d148 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -50,6 +50,7 @@ static const std::unordered_set LONG_VALUE_SETTINGS{ class BlockUtil { public: + static constexpr auto VIRTUAL_ROW_COUNT_COLUMN = "__VIRTUAL_ROW_COUNT_COLUMN__"; static constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; // Build a header block with a virtual column which will be diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp 
index f0c9612dc567..4d5eae6dc0b5 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -141,9 +141,15 @@ std::shared_ptr buildJoin( for (size_t i = 0; i < block.columns(); ++i) { const auto & column = block.getByPosition(i); - columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); if (only_one_column) + { + auto virtual_block = BlockUtil::buildRowCountBlock(column.column->size()).getColumnsWithTypeAndName(); + header = virtual_block; + columns.emplace_back(virtual_block.back()); break; + } + + columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); } DB::Block final_block(columns); diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp index 5abf1eb408b0..63e81f4be312 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.cpp +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -112,19 +112,21 @@ DB::QueryPlanPtr CrossRelParser::parseOp(const substrait::Rel & rel, std::list 0 && right_ori_header[0].name != BlockUtil::VIRTUAL_ROW_COUNT_COLUMN) { - QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), project); - project_step->setStepDescription("Rename Broadcast Table Name"); - steps.emplace_back(project_step.get()); - right.addStep(std::move(project_step)); + project = ActionsDAG::makeConvertingActions( + right_ori_header, storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), ActionsDAG::MatchColumnsMode::Position); + if (project) + { + QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), project); + project_step->setStepDescription("Rename Broadcast Table Name"); + steps.emplace_back(project_step.get()); + right.addStep(std::move(project_step)); + } } /// If the columns name in right table is duplicated with left table, we need to rename the left table's columns, @@ -134,27 +136,22 @@ void 
CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & rig const auto & right_header = right.getCurrentDataStream().header; auto left_prefix = getUniqueName("left"); for (const auto & col : left.getCurrentDataStream().header) - { if (right_header.has(col.name)) - { new_left_cols.emplace_back(col.column, col.type, left_prefix + col.name); - } else - { new_left_cols.emplace_back(col.column, col.type, col.name); - } - } - project = ActionsDAG::makeConvertingActions( - left.getCurrentDataStream().header.getColumnsWithTypeAndName(), - new_left_cols, - ActionsDAG::MatchColumnsMode::Position); - - if (project) + auto left_header = left.getCurrentDataStream().header.getColumnsWithTypeAndName(); + if (left_header.size() > 0 && left_header[0].name != BlockUtil::VIRTUAL_ROW_COUNT_COLUMN) { - QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project); - project_step->setStepDescription("Rename Left Table Name for broadcast join"); - steps.emplace_back(project_step.get()); - left.addStep(std::move(project_step)); + project = ActionsDAG::makeConvertingActions(left_header, new_left_cols, ActionsDAG::MatchColumnsMode::Position); + + if (project) + { + QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project); + project_step->setStepDescription("Rename Left Table Name for broadcast join"); + steps.emplace_back(project_step.get()); + left.addStep(std::move(project_step)); + } } } @@ -199,7 +196,15 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: extra_plan_holder.emplace_back(std::move(right)); addPostFilter(*query_plan, join); - JoinUtil::reorderJoinOutput(*query_plan, after_join_names); + Names cols; + for (auto after_join_name : after_join_names) + { + if (BlockUtil::VIRTUAL_ROW_COUNT_COLUMN == after_join_name) + continue; + + cols.emplace_back(after_join_name); + } + JoinUtil::reorderJoinOutput(*query_plan, cols); return query_plan; } diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index b6bccc480601..8ddcc7b7f93e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -144,10 +144,7 @@ trait BackendSettingsApi { def supportCartesianProductExec(): Boolean = false - def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { - case _: InnerLike | LeftOuter | RightOuter => true - case _ => false - } + def supportBroadcastNestedLoopJoinExec(): Boolean = true def supportSampleExec(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index ef25b39ab9bd..9f0ec22eebae 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec import org.apache.spark.sql.execution.metric.SQLMetric -import com.google.protobuf.{Any, StringValue} +import com.google.protobuf.Any import io.substrait.proto.CrossRel abstract class BroadcastNestedLoopJoinExecTransformer( @@ -110,19 +110,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } } - def genJoinParameters(): Any = { - // for ch - val joinParametersStr = new StringBuffer("JoinParameters:") - joinParametersStr - .append("buildHashTableId=") - .append(buildBroadcastTableId) - .append("\n") - val message = StringValue - .newBuilder() - .setValue(joinParametersStr.toString) - .build() - BackendsApiManager.getTransformerApiInstance.packPBMessage(message) - } + 
def genJoinParameters(): Any = Any.getDefaultInstance override protected def doTransform(context: SubstraitContext): TransformContext = { val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context) @@ -156,8 +144,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer( val projectRelPostJoinRel = JoinUtils.createProjectRelPostJoinRel( needSwitchChildren, joinType, - inputStreamedOutput, - inputBuildOutput, + streamedPlan.output, + buildPlan.output, context, operatorId, crossRel, @@ -174,13 +162,17 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } def validateJoinTypeAndBuildSide(): ValidationResult = { - joinType match { + val result = joinType match { case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok case _ => ValidationResult.notOk( s"Broadcast Nested Loop join is not supported join type $joinType in this backend") } + if (!result.isValid) { + return result + } + (joinType, buildSide) match { case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => ValidationResult.notOk(s"$joinType join is not supported with $buildSide") diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala index 1371c10dafcc..9dd73800e29b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala @@ -133,15 +133,10 @@ object JoinUtils { def createJoinExtensionNode( joinParameters: Any, - output: Seq[Attribute], - validation: Boolean = false): AdvancedExtensionNode = { + output: Seq[Attribute]): AdvancedExtensionNode = { // Use field [optimization] in a extension node // to send some join parameters through Substrait plan. 
- val enhancement = if (validation) { - createEnhancement(output) - } else { - null - } + val enhancement = createEnhancement(output) ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement) } @@ -352,7 +347,7 @@ object JoinUtils { .doTransform(substraitContext.registeredFunction) } val extensionNode = - createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput, validation) + createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput) RelBuilder.makeCrossRel( inputStreamedRelNode, From 9c7424ee4d8d9595eb7c15ac1faefdcc1af3b34b Mon Sep 17 00:00:00 2001 From: loneylee Date: Wed, 10 Jul 2024 14:12:43 +0800 Subject: [PATCH 4/6] fix ci error --- cpp-ch/local-engine/Parser/CrossRelParser.cpp | 17 +++++++---------- ...BroadcastNestedLoopJoinExecTransformer.scala | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp index 63e81f4be312..ea898640146b 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.cpp +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -141,17 +141,14 @@ void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & rig else new_left_cols.emplace_back(col.column, col.type, col.name); auto left_header = left.getCurrentDataStream().header.getColumnsWithTypeAndName(); - if (left_header.size() > 0 && left_header[0].name != BlockUtil::VIRTUAL_ROW_COUNT_COLUMN) - { - project = ActionsDAG::makeConvertingActions(left_header, new_left_cols, ActionsDAG::MatchColumnsMode::Position); + project = ActionsDAG::makeConvertingActions(left_header, new_left_cols, ActionsDAG::MatchColumnsMode::Position); - if (project) - { - QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project); - project_step->setStepDescription("Rename Left Table Name for broadcast join"); - steps.emplace_back(project_step.get()); - left.addStep(std::move(project_step)); - } + if (project) + { 
+ QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project); + project_step->setStepDescription("Rename Left Table Name for broadcast join"); + steps.emplace_back(project_step.get()); + left.addStep(std::move(project_step)); } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 9f0ec22eebae..39019fc0751e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -166,7 +166,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok case _ => ValidationResult.notOk( - s"Broadcast Nested Loop join is not supported join type $joinType in this backend") + s"$joinType join is not supported with BroadcastNestedLoopJoin") } if (!result.isValid) { From bb57982ec9399733ed00ec279791877187e44ae1 Mon Sep 17 00:00:00 2001 From: loneylee Date: Wed, 10 Jul 2024 14:50:44 +0800 Subject: [PATCH 5/6] fix checkstyle --- .../execution/BroadcastNestedLoopJoinExecTransformer.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 39019fc0751e..b90c1ad8b6e7 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -165,8 +165,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( val result = joinType match { case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok case 
_ => - ValidationResult.notOk( - s"$joinType join is not supported with BroadcastNestedLoopJoin") + ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") } if (!result.isValid) { From 680c8688de775b7543219fbb4f365e2070267692 Mon Sep 17 00:00:00 2001 From: loneylee Date: Wed, 10 Jul 2024 16:05:09 +0800 Subject: [PATCH 6/6] fix ci --- .../scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala index 6860d6a12958..e724cf31c689 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala @@ -110,7 +110,7 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp execution.get.fallbackNodeToReason.head._2 .contains("FullOuter join is not supported with BroadcastNestedLoopJoin")) } else { - assert(execution.get.numFallbackNodes == 2) + assert(execution.get.numFallbackNodes == 0) } }