diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java index 9cb49b6a2d30..27725998feeb 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java @@ -74,12 +74,20 @@ public static long build( return converter.genColumnNameWithExprId(attr); }) .collect(Collectors.joining(",")); + + int joinType; + if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) { + joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal(); + } else { + joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(); + } + return nativeBuild( broadCastContext.buildHashTableId(), batches, rowCount, joinKey, - SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(), + joinType, broadCastContext.hasMixedFiltCondition(), toNameStruct(output).toByteArray()); } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index d369b8c1626f..4c9edd57c930 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -297,4 +297,5 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true + } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index a5fb4a1853e8..5465e9b60b67 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -348,16 +348,29 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { metrics: Map[String, SQLMetric]): MetricsUpdater = new HashJoinMetricsUpdater(metrics) override def genNestedLoopJoinTransformerMetrics( - sparkContext: SparkContext): Map[String, SQLMetric] = { - throw new UnsupportedOperationException( - s"NestedLoopJoinTransformer metrics update is not supported in CH backend") - } + sparkContext: SparkContext): Map[String, SQLMetric] = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "outputVectors" -> SQLMetrics.createMetric(sparkContext, "number of output vectors"), + "outputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of output bytes"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input bytes"), + "extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"), + "inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"), + "outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"), + "postProjectTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"), + "probeTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time of probe"), + "totalTime" -> SQLMetrics.createTimingMetric(sparkContext, "time"), + 
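// Metric names here must match those updated by BroadcastNestedLoopJoinMetricsUpdater. +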
"fillingRightJoinSideTime" -> SQLMetrics.createTimingMetric( + sparkContext, + "filling right join side time"), + "conditionTime" -> SQLMetrics.createTimingMetric(sparkContext, "join condition time") + ) override def genNestedLoopJoinTransformerMetricsUpdater( - metrics: Map[String, SQLMetric]): MetricsUpdater = { - throw new UnsupportedOperationException( - s"NestedLoopJoinTransformer metrics update is not supported in CH backend") - } + metrics: Map[String, SQLMetric]): MetricsUpdater = new BroadcastNestedLoopJoinMetricsUpdater( + metrics) override def genSampleTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = { throw new UnsupportedOperationException( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index f5feade886b9..fc2ebda397e6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -373,8 +373,13 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { buildSide: BuildSide, joinType: JoinType, condition: Option[Expression]): BroadcastNestedLoopJoinExecTransformer = - throw new GlutenNotSupportException( - "BroadcastNestedLoopJoinExecTransformer is not supported in ch backend.") + CHBroadcastNestedLoopJoinExecTransformer( + left, + right, + buildSide, + joinType, + condition + ) override def genSampleExecTransformer( lowerBound: Double, @@ -460,16 +465,23 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { child: SparkPlan, numOutputRows: SQLMetric, dataSize: SQLMetric): BuildSideRelation = { - val hashedRelationBroadcastMode = mode.asInstanceOf[HashedRelationBroadcastMode] + + val buildKeys: Seq[Expression] = mode match { + case mode1: HashedRelationBroadcastMode => + mode1.key + case _ => + // IdentityBroadcastMode + Seq.empty + } + val (newChild, newOutput, newBuildKeys) = if ( - hashedRelationBroadcastMode.key + buildKeys .forall(k => k.isInstanceOf[AttributeReference] || k.isInstanceOf[BoundReference]) ) { (child, child.output, Seq.empty[Expression]) } else { // pre projection in case of expression join keys - val buildKeys = hashedRelationBroadcastMode.key val appendedProjections = new ArrayBuffer[NamedExpression]() val preProjectionBuildKeys = buildKeys.zipWithIndex.map { case (e, idx) => diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala new file mode 100644 index 000000000000..35be8ee0b13e --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.execution
+
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.extension.ValidationResult
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.GlutenDriverEndpoint
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftSemi}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.joins.BuildSideRelation
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import com.google.protobuf.{Any, StringValue}
+
+case class CHBroadcastNestedLoopJoinExecTransformer(
+    left: SparkPlan,
+    right: SparkPlan,
+    buildSide: BuildSide,
+    joinType: JoinType,
+    condition: Option[Expression])
+  extends BroadcastNestedLoopJoinExecTransformer(
+    left,
+    right,
+    buildSide,
+    joinType,
+    condition
+  ) {
+
+  override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
+    val streamedRDD = getColumnarInputRDDs(streamedPlan)
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    if (executionId != null) {
+      GlutenDriverEndpoint.collectResources(executionId, buildBroadcastTableId)
+    } else {
+      logWarning(
+        s"Cannot trace broadcast table data $buildBroadcastTableId" +
+          s" because the execution id is null." +
+          s" It will be cleaned up when it expires.")
+    }
+    val broadcast = buildPlan.executeBroadcast[BuildSideRelation]()
+    val context =
+      BroadCastHashJoinContext(Seq.empty, joinType, false, buildPlan.output, buildBroadcastTableId)
+    val broadcastRDD = CHBroadcastBuildSideRDD(sparkContext, broadcast, context)
+    streamedRDD :+ broadcastRDD
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan,
+      newRight: SparkPlan): CHBroadcastNestedLoopJoinExecTransformer =
+    copy(left = newLeft, right = newRight)
+
+  def isMixedCondition(cond: Option[Expression]): Boolean = {
+    // A condition is "mixed" when it references columns from both join sides.
+    if (cond.isDefined) {
+      val leftOutputSet = left.outputSet
+      val rightOutputSet = right.outputSet
+      val allReferences = cond.get.references
+      !(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet))
+    } else {
+      false
+    }
+  }
+
+  override def genJoinParameters(): Any = {
+    // Pass the build table id to the CH backend through the join parameters.
+    val joinParametersStr = new StringBuffer("JoinParameters:")
+    joinParametersStr
+      .append("buildHashTableId=")
+      .append(buildBroadcastTableId)
+      .append("\n")
+    val message = StringValue
+      .newBuilder()
+      .setValue(joinParametersStr.toString)
+      .build()
+    BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
+  }
+
+  override def validateJoinTypeAndBuildSide(): ValidationResult = {
+    joinType match {
+      case _: InnerLike =>
+      case _ =>
+        if (joinType == LeftSemi || condition.isDefined) {
+          return ValidationResult.notOk(
+            s"Broadcast nested loop join is not supported for join type $joinType with condition")
+        }
+    }
+
+    ValidationResult.ok
+  }
+
+}
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
index 59c2d6494bdb..c7f9b47de642 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}
 
 import scala.util.control.Breaks.{break, breakable}
 
@@ -89,10 +89,58 @@ case class FallbackBroadcastHashJoinPrepQueryStage(session: SparkSession) extend
           // Skip.
This might be the case that the exchange was already // executed in earlier stage } + case bnlj: BroadcastNestedLoopJoinExec => applyBNLJPrepQueryStage(bnlj) case _ => } plan } + + private def applyBNLJPrepQueryStage(bnlj: BroadcastNestedLoopJoinExec) = { + val buildSidePlan = bnlj.buildSide match { + case BuildLeft => bnlj.left + case BuildRight => bnlj.right + } + val maybeExchange = buildSidePlan.find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + maybeExchange match { + case Some(exchange @ BroadcastExchangeExec(mode, child)) => + val isTransformable = + if ( + !GlutenConfig.getConf.enableColumnarBroadcastExchange || + !GlutenConfig.getConf.enableColumnarBroadcastJoin + ) { + ValidationResult.notOk( + "columnar broadcast exchange is disabled or " + + "columnar broadcast join is disabled") + } else { + if (FallbackTags.nonEmpty(bnlj)) { + ValidationResult.notOk("broadcast join is already tagged as not transformable") + } else { + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastNestedLoopJoinExecTransformer( + bnlj.left, + bnlj.right, + bnlj.buildSide, + bnlj.joinType, + bnlj.condition) + val isTransformable = transformer.doValidate() + if (isTransformable.isValid) { + val exchangeTransformer = ColumnarBroadcastExchangeExec(mode, child) + exchangeTransformer.doValidate() + } else { + isTransformable + } + } + } + FallbackTags.add(bnlj, isTransformable) + FallbackTags.add(exchange, isTransformable) + case _ => + // Skip. This might be the case that the exchange was already + // executed in earlier stage + } + } } // For similar purpose with FallbackBroadcastHashJoinPrepQueryStage, executed during applying @@ -103,6 +151,10 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl GlutenConfig.getConf.enableColumnarBroadcastJoin && GlutenConfig.getConf.enableColumnarBroadcastExchange + private val enableColumnarBroadcastNestedLoopJoin: Boolean = + GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled && + GlutenConfig.getConf.enableColumnarBroadcastExchange + override def apply(plan: SparkPlan): SparkPlan = { plan.foreachUp { p => @@ -138,63 +190,9 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl case BuildRight => bhj.right } - val maybeExchange = buildSidePlan - .find { - case BroadcastExchangeExec(_, _) => true - case _ => false - } - .map(_.asInstanceOf[BroadcastExchangeExec]) - - maybeExchange match { - case Some(exchange @ BroadcastExchangeExec(mode, child)) => - isBhjTransformable.tagOnFallback(bhj) - if (!isBhjTransformable.isValid) { - FallbackTags.add(exchange, isBhjTransformable) - } - case None => - // we are in AQE, find the hidden exchange - // FIXME did we consider the case that AQE: OFF && Reuse: ON ? 
-          var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
-          breakable {
-            buildSidePlan.foreach {
-              case e: BroadcastExchangeLike =>
-                maybeHiddenExchange = Some(e)
-                break
-              case t: BroadcastQueryStageExec =>
-                t.plan.foreach {
-                  case e2: BroadcastExchangeLike =>
-                    maybeHiddenExchange = Some(e2)
-                    break
-                  case r: ReusedExchangeExec =>
-                    r.child match {
-                      case e2: BroadcastExchangeLike =>
-                        maybeHiddenExchange = Some(e2)
-                        break
-                      case _ =>
-                    }
-                  case _ =>
-                }
-              case _ =>
-            }
-          }
-          // restriction to force the hidden exchange to be found
-          val exchange = maybeHiddenExchange.get
-          // to conform to the underlying exchange's type, columnar or vanilla
-          exchange match {
-            case BroadcastExchangeExec(mode, child) =>
-              FallbackTags.add(
-                bhj,
-                "it's a materialized broadcast exchange or reused broadcast exchange")
-            case ColumnarBroadcastExchangeExec(mode, child) =>
-              if (!isBhjTransformable.isValid) {
-                throw new IllegalStateException(
-                  s"BroadcastExchange has already been" +
-                    s" transformed to columnar version but BHJ is determined as" +
-                    s" non-transformable: ${bhj.toString()}")
-              }
-          }
-      }
+          preTagBroadcastExchangeFallback(bhj, buildSidePlan, isBhjTransformable)
         }
+      case bnlj: BroadcastNestedLoopJoinExec => applyBNLJFallback(bnlj)
       case _ =>
     }
   } catch {
@@ -207,4 +205,88 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl
     }
     plan
   }
+
+  private def applyBNLJFallback(bnlj: BroadcastNestedLoopJoinExec) = {
+    if (!enableColumnarBroadcastNestedLoopJoin) {
+      FallbackTags.add(bnlj, "columnar broadcast nested loop join is not enabled")
+    }
+
+    val transformer = BackendsApiManager.getSparkPlanExecApiInstance
+      .genBroadcastNestedLoopJoinExecTransformer(
+        bnlj.left,
+        bnlj.right,
+        bnlj.buildSide,
+        bnlj.joinType,
+        bnlj.condition)
+
+    val isBNLJTransformable = transformer.doValidate()
+    val buildSidePlan = bnlj.buildSide match {
+      case BuildLeft => bnlj.left
+      case BuildRight => bnlj.right
+    }
+
+    preTagBroadcastExchangeFallback(bnlj, buildSidePlan, isBNLJTransformable)
+  }
+
+  private def preTagBroadcastExchangeFallback(
+      plan: SparkPlan,
+      buildSidePlan: SparkPlan,
+      isTransformable: ValidationResult): Unit = {
+    val maybeExchange = buildSidePlan
+      .find {
+        case BroadcastExchangeExec(_, _) => true
+        case _ => false
+      }
+      .map(_.asInstanceOf[BroadcastExchangeExec])
+
+    maybeExchange match {
+      case Some(exchange @ BroadcastExchangeExec(_, _)) =>
+        isTransformable.tagOnFallback(plan)
+        if (!isTransformable.isValid) {
+          FallbackTags.add(exchange, isTransformable)
+        }
+      case None =>
+        // we are in AQE, find the hidden exchange
+        // FIXME did we consider the case that AQE: OFF && Reuse: ON ?
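+        // Under AQE the broadcast exchange is wrapped in a
+        // BroadcastQueryStageExec (possibly behind a ReusedExchangeExec),
+        // so walk the build side plan to locate it.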
+        var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
+        breakable {
+          buildSidePlan.foreach {
+            case e: BroadcastExchangeLike =>
+              maybeHiddenExchange = Some(e)
+              break
+            case t: BroadcastQueryStageExec =>
+              t.plan.foreach {
+                case e2: BroadcastExchangeLike =>
+                  maybeHiddenExchange = Some(e2)
+                  break
+                case r: ReusedExchangeExec =>
+                  r.child match {
+                    case e2: BroadcastExchangeLike =>
+                      maybeHiddenExchange = Some(e2)
+                      break
+                    case _ =>
+                  }
+                case _ =>
+              }
+            case _ =>
+          }
+        }
+        // restriction to force the hidden exchange to be found
+        val exchange = maybeHiddenExchange.get
+        // to conform to the underlying exchange's type, columnar or vanilla
+        exchange match {
+          case BroadcastExchangeExec(mode, child) =>
+            FallbackTags.add(
+              plan,
+              "it's a materialized broadcast exchange or reused broadcast exchange")
+          case ColumnarBroadcastExchangeExec(mode, child) =>
+            if (!isTransformable.isValid) {
+              throw new IllegalStateException(
+                s"BroadcastExchange has already been" +
+                  s" transformed to columnar version but the join is determined as" +
+                  s" non-transformable: ${plan.toString()}")
+            }
+        }
+    }
+  }
 }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala
new file mode 100644
index 000000000000..b1414bf9727c
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/metrics/BroadcastNestedLoopJoinMetricsUpdater.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.gluten.metrics + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.metric.SQLMetric + +class BroadcastNestedLoopJoinMetricsUpdater(val metrics: Map[String, SQLMetric]) + extends MetricsUpdater + with Logging { + + override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = { + try { + if (opMetrics != null) { + val operatorMetrics = opMetrics.asInstanceOf[OperatorMetrics] + if (!operatorMetrics.metricsList.isEmpty && operatorMetrics.joinParams != null) { + val joinParams = operatorMetrics.joinParams + var currentIdx = operatorMetrics.metricsList.size() - 1 + var totalTime = 0L + + // update fillingRightJoinSideTime + MetricsUtil + .getAllProcessorList(operatorMetrics.metricsList.get(currentIdx)) + .foreach( + processor => { + if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) { + metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong + } + }) + + // joining + val joinMetricsData = operatorMetrics.metricsList.get(currentIdx) + metrics("outputVectors") += joinMetricsData.outputVectors + metrics("inputWaitTime") += (joinMetricsData.inputWaitTime / 1000L).toLong + metrics("outputWaitTime") += (joinMetricsData.outputWaitTime / 1000L).toLong + totalTime += joinMetricsData.time + + MetricsUtil + .getAllProcessorList(joinMetricsData) + .foreach( + processor => { + if (processor.name.equalsIgnoreCase("FillingRightJoinSide")) { + metrics("fillingRightJoinSideTime") += (processor.time / 1000L).toLong + } + if (processor.name.equalsIgnoreCase("FilterTransform")) { + metrics("conditionTime") += (processor.time / 1000L).toLong + metrics("numOutputRows") += processor.outputRows - processor.inputRows + metrics("outputBytes") += processor.outputBytes - processor.inputBytes + } + if (processor.name.equalsIgnoreCase("JoiningTransform")) { + metrics("probeTime") += (processor.time / 1000L).toLong + } + if ( + !BroadcastNestedLoopJoinMetricsUpdater.INCLUDING_PROCESSORS.contains( + processor.name) + ) { + metrics("extraTime") += (processor.time / 1000L).toLong + } + if ( + BroadcastNestedLoopJoinMetricsUpdater.CH_PLAN_NODE_NAME.contains(processor.name) + ) { + metrics("numOutputRows") += processor.outputRows + metrics("outputBytes") += processor.outputBytes + metrics("numInputRows") += processor.inputRows + metrics("inputBytes") += processor.inputBytes + } + }) + + currentIdx -= 1 + + // post projection + if (joinParams.postProjectionNeeded) { + metrics("postProjectTime") += + (operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong + metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors + totalTime += operatorMetrics.metricsList.get(currentIdx).time + currentIdx -= 1 + } + metrics("totalTime") += (totalTime / 1000L).toLong + } + } + } catch { + case e: Exception => + logError(s"Updating native metrics failed due to ${e.getCause}.") + throw e + } + } +} + +object BroadcastNestedLoopJoinMetricsUpdater { + val INCLUDING_PROCESSORS = Array("JoiningTransform", "FillingRightJoinSide", "FilterTransform") + val CH_PLAN_NODE_NAME = Array("JoiningTransform") +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala index bcdc1f5ef514..4a732785c1b3 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -56,17 +56,11 @@ abstract class GlutenClickHouseTPCDSAbstractSuite Seq("q" + "%d".format(queryNum)) } val noFallBack = queryNum match { - case i - if i == 10 || i == 16 || i == 28 || i == 35 || i == 45 || i == 77 || - i == 88 || i == 90 || i == 94 => + case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 => // Q10 BroadcastHashJoin, ExistenceJoin // Q16 ShuffledHashJoin, NOT condition - // Q28 BroadcastNestedLoopJoin // Q35 BroadcastHashJoin, ExistenceJoin // Q45 BroadcastHashJoin, ExistenceJoin - // Q77 CartesianProduct - // Q88 BroadcastNestedLoopJoin - // Q90 BroadcastNestedLoopJoin // Q94 BroadcastHashJoin, LeftSemi, NOT condition (false, false) case j if j == 38 || j == 87 => @@ -76,6 +70,9 @@ abstract class GlutenClickHouseTPCDSAbstractSuite } else { (false, true) } + case q77 if q77 == 77 && !isAqe => + // Q77 CartesianProduct + (false, false) case other => (true, false) } sqlNums.map((_, noFallBack._1, noFallBack._2)) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index b0d3e1bdb866..4005d9f2e02a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -1743,7 +1743,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql1, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql1, true, { _ => }) val sql2 = """ @@ -1764,7 +1764,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql2, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql2, true, { _ => }) val sql3 = """ @@ -1785,7 +1785,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql3, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql3, true, { _ => }) val sql4 = """ @@ -1806,7 +1806,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql4, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql4, true, { _ => }) val sql5 = """ @@ -1827,7 +1827,7 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr | on t0.a = t1.a | ) t3 | )""".stripMargin - compareResultsAgainstVanillaSpark(sql5, true, { _ => }, false) + compareResultsAgainstVanillaSpark(sql5, true, { _ => }) } test("GLUTEN-1874 not null in one stream") { diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 800039f1d262..850c863d0bfd 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -51,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,6 @@ #include #include #include -#include #include #include #include @@ -84,8 +84,6 @@ extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; namespace local_engine { -constexpr auto VIRTUAL_ROW_COUNT_COLUMN = 
"__VIRTUAL_ROW_COUNT_COLUMN__"; - namespace fs = std::filesystem; DB::Block BlockUtil::buildRowCountHeader() @@ -128,6 +126,27 @@ DB::Block BlockUtil::buildHeader(const DB::NamesAndTypesList & names_types_list) return DB::Block(cols); } +/// The column names may be different in two blocks. +/// and the nullability also could be different, with TPCDS-Q1 as an example. +DB::ColumnWithTypeAndName +BlockUtil::convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column) +{ + if (sample_column.type->equals(*column.type)) + return {column.column, column.type, sample_column.name}; + else if (sample_column.type->isNullable() && !column.type->isNullable() && DB::removeNullable(sample_column.type)->equals(*column.type)) + { + auto nullable_column = column; + DB::JoinCommon::convertColumnToNullable(nullable_column); + return {nullable_column.column, sample_column.type, sample_column.name}; + } + else + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "Columns have different types. original:{} expected:{}", + column.dumpStructure(), + sample_column.dumpStructure()); +} + /** * There is a special case with which we need be careful. In spark, struct/map/list are always * wrapped in Nullable, but this should not happen in clickhouse. @@ -1054,4 +1073,53 @@ UInt64 MemoryUtil::getMemoryRSS() return rss * sysconf(_SC_PAGESIZE); } + +void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) +{ + ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); + NamesWithAliases project_cols; + for (const auto & col : cols) + { + project_cols.emplace_back(NameWithAlias(col, col)); + } + project->project(project_cols); + QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); + project_step->setStepDescription("Reorder Join Output"); + plan.addStep(std::move(project_step)); +} + +std::pair JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + return {DB::JoinKind::Inner, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; + case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: + return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + return {DB::JoinKind::Left, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: + return {DB::JoinKind::Right, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Full, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + +std::pair JoinUtil::getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: + case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Cross, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + } diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 3ac0f63ce10b..65764af7d148 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -16,6 +16,7 @@ * limitations under the License. 
*/ #pragma once + #include #include #include @@ -25,6 +26,8 @@ #include #include #include +#include +#include #include namespace DB @@ -47,6 +50,9 @@ static const std::unordered_set LONG_VALUE_SETTINGS{ class BlockUtil { public: + static constexpr auto VIRTUAL_ROW_COUNT_COLUMN = "__VIRTUAL_ROW_COUNT_COLUMN__"; + static constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; + // Build a header block with a virtual column which will be // use to indicate the number of rows in a block. // Commonly seen in the following quries: @@ -72,6 +78,10 @@ class BlockUtil const std::unordered_set & columns_to_skip_flatten = {}); static DB::Block concatenateBlocksMemoryEfficiently(std::vector && blocks); + + /// The column names may be different in two blocks. + /// and the nullability also could be different, with TPCDS-Q1 as an example. + static DB::ColumnWithTypeAndName convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column); }; class PODArrayUtil @@ -296,4 +306,12 @@ class ConcurrentDeque mutable std::mutex mtx; }; +class JoinUtil +{ +public: + static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); + static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); +}; + } diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 1c79a00a7c4c..4d5eae6dc0b5 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -15,16 +15,18 @@ * limitations under the License. */ #include "BroadCastJoinBuilder.h" + #include #include #include #include #include +#include #include #include #include #include -#include +#include #include #include @@ -52,6 +54,20 @@ jlong callJavaGet(const std::string & id) return result; } +DB::Block resetBuildTableBlockName(Block & block, bool only_one = false) +{ + DB::ColumnsWithTypeAndName new_cols; + for (const auto & col : block) + { + // Add a prefix to avoid column name conflicts with left table. 
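+        // e.g. a build-side column "id" becomes "broadcast_right_id".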
+        new_cols.emplace_back(col.column, col.type, BlockUtil::RIHGT_COLUMN_PREFIX + col.name);
+
+        if (only_one)
+            break;
+    }
+    return DB::Block(new_cols);
+}
+
 void cleanBuildHashTable(const std::string & hash_table_id, jlong instance)
 {
     /// Thread status holds raw pointer on query context, thus it always must be destroyed
@@ -81,26 +97,71 @@ std::shared_ptr buildJoin(
     DB::ReadBuffer & input,
     jlong row_count,
     const std::string & join_keys,
-    substrait::JoinRel_JoinType join_type,
+    jint join_type,
     bool has_mixed_join_condition,
     const std::string & named_struct)
 {
     auto join_key_list = Poco::StringTokenizer(join_keys, ",");
     Names key_names;
     for (const auto & key_name : join_key_list)
-        key_names.emplace_back(key_name);
+        key_names.emplace_back(BlockUtil::RIHGT_COLUMN_PREFIX + key_name);
+
     DB::JoinKind kind;
     DB::JoinStrictness strictness;
-    std::tie(kind, strictness) = getJoinKindAndStrictness(join_type);
+    if (key.starts_with("BuiltBNLJBroadcastTable-"))
+        std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast<substrait::CrossRel_JoinType>(join_type));
+    else
+        std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast<substrait::JoinRel_JoinType>(join_type));
+
     substrait::NamedStruct substrait_struct;
     substrait_struct.ParseFromString(named_struct);
     Block header = TypeParser::buildBlockFromNamedStruct(substrait_struct);
+    header = resetBuildTableBlockName(header);
+
+    Blocks data;
+    {
+        bool header_empty = header.getNamesAndTypesList().empty();
+        bool only_one_column = header_empty;
+        NativeReader block_stream(input);
+        ProfileInfo info;
+        while (Block block = block_stream.read())
+        {
+            if (header_empty)
+            {
+                // In BNLJ the build side output may be empty, e.g.
+                //   select count(*) from t1 left join t2
+                // In that case derive the build side header from the first block.
+                header = resetBuildTableBlockName(block, true);
+                header_empty = false;
+            }
+
+            DB::ColumnsWithTypeAndName columns;
+            for (size_t i = 0; i < block.columns(); ++i)
+            {
+                const auto & column = block.getByPosition(i);
+                if (only_one_column)
+                {
+                    auto virtual_block = BlockUtil::buildRowCountBlock(column.column->size()).getColumnsWithTypeAndName();
+                    header = virtual_block;
+                    columns.emplace_back(virtual_block.back());
+                    break;
+                }
+
+                columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i)));
+            }
+
+            DB::Block final_block(columns);
+            info.update(final_block);
+            data.emplace_back(std::move(final_block));
+        }
+    }
+
     ColumnsDescription columns_description(header.getNamesAndTypesList());
     return make_shared(
-        input,
+        data,
         row_count,
         key_names,
         true,
diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
index 9a6837e35a0a..3d2e67f9df10 100644
--- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
+++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h
@@ -35,7 +35,7 @@ std::shared_ptr buildJoin(
     DB::ReadBuffer & input,
     jlong row_count,
     const std::string & join_keys,
-    substrait::JoinRel_JoinType join_type,
+    jint join_type,
     bool has_mixed_join_condition,
     const std::string & named_struct);
 void cleanBuildHashTable(const std::string & hash_table_id, jlong instance);
diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
index 326e11a84f81..2f5afd434b41 100644
--- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
+++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.cpp
@@ -16,12 +16,10 @@
  */
 #include "StorageJoinFromReadBuffer.h"
-#include
 #include
 #include
 #include
-#include
-#include
+#include
 #include
 #include
@@
-43,8 +41,6 @@ extern const int DEADLOCK_AVOIDED; using namespace DB; -constexpr auto RIHGT_COLUMN_PREFIX = "broadcast_right_"; - DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata & storage_metadata_, JoinKind kind) { DB::ColumnsWithTypeAndName new_cols; @@ -52,7 +48,7 @@ DB::Block rightSampleBlock(bool use_nulls, const StorageInMemoryMetadata & stora for (const auto & col : block) { // Add a prefix to avoid column name conflicts with left table. - new_cols.emplace_back(col.column, col.type, RIHGT_COLUMN_PREFIX + col.name); + new_cols.emplace_back(col.column, col.type, col.name); if (use_nulls && isLeftOrFull(kind)) { auto & new_col = new_cols.back(); @@ -66,7 +62,7 @@ namespace local_engine { StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( - DB::ReadBuffer & in, + Blocks & data, size_t row_count_, const Names & key_names_, bool use_nulls_, @@ -77,7 +73,7 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( const ConstraintsDescription & constraints, const String & comment, const bool overwrite_) - : key_names({}), use_nulls(use_nulls_), row_count(row_count_), overwrite(overwrite_) + : key_names(key_names_), use_nulls(use_nulls_), row_count(row_count_), overwrite(overwrite_) { storage_metadata.setColumns(columns); storage_metadata.setConstraints(constraints); @@ -86,74 +82,33 @@ StorageJoinFromReadBuffer::StorageJoinFromReadBuffer( for (const auto & key : key_names_) if (!storage_metadata.getColumns().hasPhysical(key)) throw Exception(ErrorCodes::NO_SUCH_COLUMN_IN_TABLE, "Key column ({}) does not exist in table declaration.", key); - for (const auto & name : key_names_) - key_names.push_back(RIHGT_COLUMN_PREFIX + name); auto table_join = std::make_shared(SizeLimits(), true, kind, strictness, key_names); - right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind()); - /// If there is mixed join conditions, need to build the hash join lazily, which rely on the real table join. - if (!has_mixed_join_condition) - buildJoin(in, right_sample_block, table_join); - else - collectAllInputs(in, right_sample_block); -} -/// The column names may be different in two blocks. -/// and the nullability also could be different, with TPCDS-Q1 as an example. -static DB::ColumnWithTypeAndName convertColumnAsNecessary(const DB::ColumnWithTypeAndName & column, const DB::ColumnWithTypeAndName & sample_column) -{ - if (sample_column.type->equals(*column.type)) - return {column.column, column.type, sample_column.name}; - else if ( - sample_column.type->isNullable() && !column.type->isNullable() - && DB::removeNullable(sample_column.type)->equals(*column.type)) + if (key_names.empty()) { - auto nullable_column = column; - DB::JoinCommon::convertColumnToNullable(nullable_column); - return {nullable_column.column, sample_column.type, sample_column.name}; + // For bnlj cross join, keys clauses should be empty. + table_join->resetKeys(); } + + right_sample_block = rightSampleBlock(use_nulls, storage_metadata, table_join->kind()); + /// If there is mixed join conditions, need to build the hash join lazily, which rely on the real table join. + if (!has_mixed_join_condition) + buildJoin(data, right_sample_block, table_join); else - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Columns have different types. 
original:{} expected:{}", - column.dumpStructure(), - sample_column.dumpStructure()); + collectAllInputs(data, right_sample_block); } -void StorageJoinFromReadBuffer::buildJoin(DB::ReadBuffer & in, const Block header, std::shared_ptr analyzed_join) +void StorageJoinFromReadBuffer::buildJoin(Blocks & data, const Block header, std::shared_ptr analyzed_join) { - local_engine::NativeReader block_stream(in); - ProfileInfo info; join = std::make_shared(analyzed_join, header, overwrite, row_count); - while (Block block = block_stream.read()) - { - DB::ColumnsWithTypeAndName columns; - for (size_t i = 0; i < block.columns(); ++i) - { - const auto & column = block.getByPosition(i); - columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); - } - DB::Block final_block(columns); - info.update(final_block); - join->addBlockToJoin(final_block, true); - } + for (Block block : data) + join->addBlockToJoin(std::move(block), true); } -void StorageJoinFromReadBuffer::collectAllInputs(DB::ReadBuffer & in, const DB::Block header) +void StorageJoinFromReadBuffer::collectAllInputs(Blocks & data, const DB::Block) { - local_engine::NativeReader block_stream(in); - ProfileInfo info; - while (Block block = block_stream.read()) - { - DB::ColumnsWithTypeAndName columns; - for (size_t i = 0; i < block.columns(); ++i) - { - const auto & column = block.getByPosition(i); - columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); - } - DB::Block final_block(columns); - info.update(final_block); - input_blocks.emplace_back(std::move(final_block)); - } + for (Block block : data) + input_blocks.emplace_back(std::move(block)); } void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_ptr analyzed_join) @@ -174,7 +129,7 @@ void StorageJoinFromReadBuffer::buildJoinLazily(DB::Block header, std::shared_pt for (size_t i = 0; i < block.columns(); ++i) { const auto & column = block.getByPosition(i); - columns.emplace_back(convertColumnAsNecessary(column, header.getByPosition(i))); + columns.emplace_back(BlockUtil::convertColumnAsNecessary(column, header.getByPosition(i))); } DB::Block final_block(columns); join->addBlockToJoin(final_block, true); diff --git a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h index ddefda69c30f..600210e66869 100644 --- a/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h +++ b/cpp-ch/local-engine/Join/StorageJoinFromReadBuffer.h @@ -35,7 +35,7 @@ class StorageJoinFromReadBuffer { public: StorageJoinFromReadBuffer( - DB::ReadBuffer & in_, + DB::Blocks & data, size_t row_count, const DB::Names & key_names_, bool use_nulls_, @@ -65,8 +65,8 @@ class StorageJoinFromReadBuffer std::shared_ptr join = nullptr; void readAllBlocksFromInput(DB::ReadBuffer & in); - void buildJoin(DB::ReadBuffer & in, const DB::Block header, std::shared_ptr analyzed_join); - void collectAllInputs(DB::ReadBuffer & in, const DB::Block header); + void buildJoin(DB::Blocks & data, const DB::Block header, std::shared_ptr analyzed_join); + void collectAllInputs(DB::Blocks & data, const DB::Block header); void buildJoinLazily(DB::Block header, std::shared_ptr analyzed_join); }; } diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp new file mode 100644 index 000000000000..ea898640146b --- /dev/null +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "CrossRelParser.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_TYPE; + extern const int BAD_ARGUMENTS; +} +} + +using namespace DB; + + + + +namespace local_engine +{ +String parseCrossJoinOptimizationInfos(const substrait::CrossRel & join) +{ + google::protobuf::StringValue optimization; + optimization.ParseFromString(join.advanced_extension().optimization().value()); + String storage_join_key; + ReadBufferFromString in(optimization.value()); + assertString("JoinParameters:", in); + assertString("buildHashTableId=", in); + readString(storage_join_key, in); + return storage_join_key; +} + +std::shared_ptr createCrossTableJoin(substrait::CrossRel_JoinType join_type) +{ + auto & global_context = SerializedPlanParser::global_context; + auto table_join = std::make_shared( + global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); + + std::pair kind_and_strictness = JoinUtil::getCrossJoinKindAndStrictness(join_type); + table_join->setKind(kind_and_strictness.first); + table_join->setStrictness(kind_and_strictness.second); + return table_join; +} + +CrossRelParser::CrossRelParser(SerializedPlanParser * plan_paser_) + : RelParser(plan_paser_) + , function_mapping(plan_paser_->function_mapping) + , context(plan_paser_->context) + , extra_plan_holder(plan_paser_->extra_plan_holder) +{ +} + +DB::QueryPlanPtr +CrossRelParser::parse(DB::QueryPlanPtr /*query_plan*/, const substrait::Rel & /*rel*/, std::list & /*rel_stack_*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call parse()."); +} + +const substrait::Rel & CrossRelParser::getSingleInput(const substrait::Rel & /*rel*/) +{ + throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call getSingleInput()."); +} + +DB::QueryPlanPtr CrossRelParser::parseOp(const substrait::Rel & rel, std::list & rel_stack) +{ + const auto & join = rel.cross(); + if (!join.has_left() || !join.has_right()) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "left table or right table is missing."); + } + + rel_stack.push_back(&rel); + auto left_plan = getPlanParser()->parseOp(join.left(), rel_stack); + auto right_plan = getPlanParser()->parseOp(join.right(), rel_stack); + rel_stack.pop_back(); + + return parseJoin(join, std::move(left_plan), std::move(right_plan)); +} + +void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join) +{ + ActionsDAGPtr project = nullptr; + /// To support mixed join conditions, we must make sure that the column names in the right 
side are the same as
+    /// storage_join's right sample block.
+    auto right_ori_header = right.getCurrentDataStream().header.getColumnsWithTypeAndName();
+    if (!right_ori_header.empty() && right_ori_header[0].name != BlockUtil::VIRTUAL_ROW_COUNT_COLUMN)
+    {
+        project = ActionsDAG::makeConvertingActions(
+            right_ori_header, storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), ActionsDAG::MatchColumnsMode::Position);
+        if (project)
+        {
+            QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), project);
+            project_step->setStepDescription("Rename Broadcast Table Name");
+            steps.emplace_back(project_step.get());
+            right.addStep(std::move(project_step));
+        }
+    }
+
+    /// If column names in the right table duplicate those in the left table, rename the left table's
+    /// columns, so that the right table's column names are not changed in `addConvertStep`.
+    /// This could happen in TPC-DS q44.
+    DB::ColumnsWithTypeAndName new_left_cols;
+    const auto & right_header = right.getCurrentDataStream().header;
+    auto left_prefix = getUniqueName("left");
+    for (const auto & col : left.getCurrentDataStream().header)
+        if (right_header.has(col.name))
+            new_left_cols.emplace_back(col.column, col.type, left_prefix + col.name);
+        else
+            new_left_cols.emplace_back(col.column, col.type, col.name);
+    auto left_header = left.getCurrentDataStream().header.getColumnsWithTypeAndName();
+    project = ActionsDAG::makeConvertingActions(left_header, new_left_cols, ActionsDAG::MatchColumnsMode::Position);
+
+    if (project)
+    {
+        QueryPlanStepPtr project_step = std::make_unique(left.getCurrentDataStream(), project);
+        project_step->setStepDescription("Rename Left Table Name for broadcast join");
+        steps.emplace_back(project_step.get());
+        left.addStep(std::move(project_step));
+    }
+}
+
+DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right)
+{
+    auto storage_join_key = parseCrossJoinOptimizationInfos(join);
+    auto storage_join = BroadCastJoinBuilder::getJoin(storage_join_key);
+    renamePlanColumns(*left, *right, *storage_join);
+    auto table_join = createCrossTableJoin(join.type());
+    DB::Block right_header_before_convert_step = right->getCurrentDataStream().header;
+    addConvertStep(*table_join, *left, *right);
+
+    // Sanity check so that renaming errors surface early.
+    if (!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header))
+    {
+        throw DB::Exception(
+            DB::ErrorCodes::LOGICAL_ERROR,
+            "For broadcast join, the column names in the right table must not change.\nleft header:{},\nright header: {} -> {}",
+            left->getCurrentDataStream().header.dumpNames(),
+            right_header_before_convert_step.dumpNames(),
+            right->getCurrentDataStream().header.dumpNames());
+    }
+
+    Names after_join_names;
+    auto left_names = left->getCurrentDataStream().header.getNames();
+    after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end());
+    auto right_name = table_join->columnsFromJoinedTable().getNames();
+    after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end());
+
+    auto left_header = left->getCurrentDataStream().header;
+    auto right_header = right->getCurrentDataStream().header;
+
+    QueryPlanPtr query_plan;
+    table_join->addDisjunct();
+    auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context);
+    QueryPlanStepPtr join_step = std::make_unique(left->getCurrentDataStream(), broadcast_hash_join, 8192);
+
+    join_step->setStepDescription("STORAGE_JOIN");
+    steps.emplace_back(join_step.get());
+    left->addStep(std::move(join_step));
+    query_plan = std::move(left);
+    /// hold the right plan for profiling
+    extra_plan_holder.emplace_back(std::move(right));
+
+    addPostFilter(*query_plan, join);
+    Names cols;
+    for (const auto & after_join_name : after_join_names)
+    {
+        if (BlockUtil::VIRTUAL_ROW_COUNT_COLUMN == after_join_name)
+            continue;
+
+        cols.emplace_back(after_join_name);
+    }
+    JoinUtil::reorderJoinOutput(*query_plan, cols);
+
+    return query_plan;
+}
+
+void CrossRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join_rel)
+{
+    if (!join_rel.has_expression())
+        return;
+
+    auto expression = join_rel.expression();
+    std::string filter_name;
+    auto actions_dag = std::make_shared(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName());
+    if (!expression.has_scalar_function())
+    {
+        // It may be singular_or_list
+        auto * in_node = getPlanParser()->parseExpression(actions_dag, expression);
+        filter_name = in_node->result_name;
+    }
+    else
+    {
+        getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header, expression, filter_name, actions_dag, true);
+    }
+    auto filter_step = std::make_unique(query_plan.getCurrentDataStream(), actions_dag, filter_name, true);
+    filter_step->setStepDescription("Post Join Filter");
+    steps.emplace_back(filter_step.get());
+    query_plan.addStep(std::move(filter_step));
+}
+
+void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right)
+{
+    /// If column names in the right table duplicate those in the left table, rename the right table's columns.
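+    /// TableJoin qualifies a clashing right-side name with the unique prefix
+    /// passed to setColumnsFromJoinedTable below, e.g. "id" -> "right.id"
+    /// (modulo the suffix getUniqueName appends).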
+ NameSet left_columns_set; + for (const auto & col : left.getCurrentDataStream().header.getNames()) + left_columns_set.emplace(col); + table_join.setColumnsFromJoinedTable( + right.getCurrentDataStream().header.getNamesAndTypesList(), left_columns_set, getUniqueName("right") + "."); + + // fix right table key duplicate + NamesWithAliases right_table_alias; + for (size_t idx = 0; idx < table_join.columnsFromJoinedTable().size(); idx++) + { + auto origin_name = right.getCurrentDataStream().header.getByPosition(idx).name; + auto dedup_name = table_join.columnsFromJoinedTable().getNames().at(idx); + if (origin_name != dedup_name) + { + right_table_alias.emplace_back(NameWithAlias(origin_name, dedup_name)); + } + } + if (!right_table_alias.empty()) + { + ActionsDAGPtr rename_dag = std::make_shared(right.getCurrentDataStream().header.getNamesAndTypesList()); + auto original_right_columns = right.getCurrentDataStream().header; + for (const auto & column_alias : right_table_alias) + { + if (original_right_columns.has(column_alias.first)) + { + auto pos = original_right_columns.getPositionByName(column_alias.first); + const auto & alias = rename_dag->addAlias(*rename_dag->getInputs()[pos], column_alias.second); + rename_dag->getOutputs()[pos] = &alias; + } + } + + QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), rename_dag); + project_step->setStepDescription("Right Table Rename"); + steps.emplace_back(project_step.get()); + right.addStep(std::move(project_step)); + } + + for (const auto & column : table_join.columnsFromJoinedTable()) + { + table_join.addJoinedColumn(column); + } + ActionsDAGPtr left_convert_actions = nullptr; + ActionsDAGPtr right_convert_actions = nullptr; + std::tie(left_convert_actions, right_convert_actions) = table_join.createConvertingActions( + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), right.getCurrentDataStream().header.getColumnsWithTypeAndName()); + + if (right_convert_actions) + { + auto converting_step = std::make_unique(right.getCurrentDataStream(), right_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + steps.emplace_back(converting_step.get()); + right.addStep(std::move(converting_step)); + } + + if (left_convert_actions) + { + auto converting_step = std::make_unique(left.getCurrentDataStream(), left_convert_actions); + converting_step->setStepDescription("Convert joined columns"); + steps.emplace_back(converting_step.get()); + left.addStep(std::move(converting_step)); + } +} + + +void registerCrossRelParser(RelParserFactory & factory) +{ + auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; + factory.registerBuilder(substrait::Rel::RelTypeCase::kCross, builder); +} + +} diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.h b/cpp-ch/local-engine/Parser/CrossRelParser.h new file mode 100644 index 000000000000..f1cd60385e26 --- /dev/null +++ b/cpp-ch/local-engine/Parser/CrossRelParser.h @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace DB +{ +class TableJoin; +} + +namespace local_engine +{ + +class StorageJoinFromReadBuffer; + + +class CrossRelParser : public RelParser +{ +public: + explicit CrossRelParser(SerializedPlanParser * plan_paser_); + ~CrossRelParser() override = default; + + DB::QueryPlanPtr + parse(DB::QueryPlanPtr query_plan, const substrait::Rel & sort_rel, std::list & rel_stack_) override; + + DB::QueryPlanPtr parseOp(const substrait::Rel & rel, std::list & rel_stack) override; + + const substrait::Rel & getSingleInput(const substrait::Rel & rel) override; + +private: + std::unordered_map & function_mapping; + ContextPtr & context; + std::vector & extra_plan_holder; + + + DB::QueryPlanPtr parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right); + void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join); + void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right); + void addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join); + bool applyJoinFilter( + DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition); +}; + +} diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index a6a146954d6f..03734a2a9f0d 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -15,6 +15,7 @@ * limitations under the License. 
*/ #include "JoinRelParser.h" + #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include @@ -98,51 +100,15 @@ JoinOptimizationInfo parseJoinOptimizationInfo(const substrait::JoinRel & join) return info; } - -void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) -{ - ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); - NamesWithAliases project_cols; - for (const auto & col : cols) - { - project_cols.emplace_back(NameWithAlias(col, col)); - } - project->project(project_cols); - QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); - project_step->setStepDescription("Reorder Join Output"); - plan.addStep(std::move(project_step)); -} - namespace local_engine { - -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) -{ - switch (join_type) - { - case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: - return {DB::JoinKind::Inner, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: - return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; - case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: - return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: - return {DB::JoinKind::Left, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: - return {DB::JoinKind::Right, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Full, DB::JoinStrictness::All}; - default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); - } -} std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = getJoinKindAndStrictness(join_type); + std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -436,7 +402,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } - reorderJoinOutput(*query_plan, after_join_names); + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index c423f43908e7..15468b54b6f4 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -31,8 +31,6 @@ namespace local_engine class StorageJoinFromReadBuffer; -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); - class JoinRelParser : public RelParser { public: diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index 282339c4d641..f651146a391d 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -156,6 +156,7 @@ void registerAggregateParser(RelParserFactory & factory); void registerProjectRelParser(RelParserFactory & factory); void registerJoinRelParser(RelParserFactory & factory); void registerFilterRelParser(RelParserFactory & factory); +void registerCrossRelParser(RelParserFactory 
& factory); void registerRelParsers() { @@ -166,6 +167,7 @@ void registerRelParsers() registerAggregateParser(factory); registerProjectRelParser(factory); registerJoinRelParser(factory); + registerCrossRelParser(factory); registerFilterRelParser(factory); } } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 8c60c6e500a9..1174faf1fe58 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -528,6 +528,7 @@ QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list case substrait::Rel::RelTypeCase::kSort: case substrait::Rel::RelTypeCase::kWindow: case substrait::Rel::RelTypeCase::kJoin: + case substrait::Rel::RelTypeCase::kCross: case substrait::Rel::RelTypeCase::kExpand: { auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this); query_plan = op_parser->parseOp(rel, rel_stack); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 477fdb1f6d44..6364faae1513 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -255,6 +255,7 @@ class SerializedPlanParser friend class FunctionExecutor; friend class NonNullableColumnsResolver; friend class JoinRelParser; + friend class CrossRelParser; friend class MergeTreeRelParser; friend class ProjectRelParser; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 695fc8585538..627e6154cdcf 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -1121,13 +1121,12 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild const auto named_struct_a = local_engine::getByteArrayElementsSafe(env, named_struct); const std::string::size_type struct_size = named_struct_a.length(); std::string struct_string{reinterpret_cast(named_struct_a.elems()), struct_size}; - const auto join_type = static_cast(join_type_); const jsize length = env->GetArrayLength(in); local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, length); DB::CompressedReadBuffer input(read_buffer_from_java_array); local_engine::configureCompressedReadBuffer(input); const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin( - hash_table_id, input, row_count_, join_key, join_type, has_mixed_join_condition, struct_string)); + hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, struct_string)); return obj->instance(); LOCAL_ENGINE_JNI_METHOD_END(env, 0) } diff --git a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto index 0e51baf5ad4c..3813de868445 100644 --- a/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto +++ b/gluten-core/src/main/resources/substrait/proto/substrait/algebra.proto @@ -259,6 +259,7 @@ message CrossRel { JOIN_TYPE_OUTER = 2; JOIN_TYPE_LEFT = 3; JOIN_TYPE_RIGHT = 4; + JOIN_TYPE_LEFT_SEMI = 5; } substrait.extensions.AdvancedExtension advanced_extension = 10; diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index d159486373ac..8ddcc7b7f93e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
index d159486373ac..8ddcc7b7f93e 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala
@@ -144,7 +144,7 @@ trait BackendSettingsApi {
 
   def supportCartesianProductExec(): Boolean = false
 
-  def supportBroadcastNestedLoopJoinExec(): Boolean = false
+  def supportBroadcastNestedLoopJoinExec(): Boolean = true
 
   def supportSampleExec(): Boolean = false
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
index 092612ea7340..b90c1ad8b6e7 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala
@@ -16,20 +16,22 @@
  */
 package org.apache.gluten.execution
 
+import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.metrics.MetricsUpdater
-import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.{JoinParams, SubstraitContext}
 import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
-import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.joins.BaseJoinExec
 import org.apache.spark.sql.execution.metric.SQLMetric
 
+import com.google.protobuf.Any
 import io.substrait.proto.CrossRel
 
 abstract class BroadcastNestedLoopJoinExecTransformer(
@@ -49,7 +51,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
   private lazy val substraitJoinType: CrossRel.JoinType =
     SubstraitUtil.toCrossRelSubstrait(joinType)
 
-  private lazy val buildTableId: String = "BuildTable-" + buildPlan.id
+  // Unique ID for the built table
+  lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id
 
   // Hint substrait to switch the left and right,
   // since we assume always build right side in substrait.
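Note the default flip in BackendSettingsApi: backends are now assumed to support BNLJ unless they opt out. A minimal sketch of opting out (the mixin name is invented):

    import org.apache.gluten.backendsapi.BackendSettingsApi

    // Hypothetical settings mixin for a backend that still cannot offload BNLJ;
    // with the trait default now `true`, disabling it is an explicit override.
    trait DisableBnljSettings extends BackendSettingsApi {
      override def supportBroadcastNestedLoopJoinExec(): Boolean = false
    }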
@@ -79,6 +82,10 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
       left.output ++ right.output.map(_.withNullability(true))
     case RightOuter =>
       left.output.map(_.withNullability(true)) ++ right.output
+    case LeftExistence(_) =>
+      left.output
+    case FullOuter =>
+      left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
     case x =>
       throw new IllegalArgumentException(s"${getClass.getSimpleName} not take $x as the JoinType")
   }
@@ -103,6 +110,8 @@
     }
   }
 
+  def genJoinParameters(): Any = Any.getDefaultInstance
+
   override protected def doTransform(context: SubstraitContext): TransformContext = {
     val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context)
     val (inputStreamedRelNode, inputStreamedOutput) =
@@ -113,6 +122,10 @@
       (buildPlanContext.root, buildPlanContext.outputAttributes)
 
     val operatorId = context.nextOperatorId(this.nodeName)
+    val joinParams = new JoinParams
+    if (condition.isDefined) {
+      joinParams.isWithCondition = true
+    }
 
     val crossRel = JoinUtils.createCrossRel(
       substraitJoinType,
@@ -122,14 +135,17 @@
       inputStreamedOutput,
       inputBuildOutput,
       context,
-      operatorId
+      operatorId,
+      genJoinParameters()
     )
 
+    context.registerJoinParam(operatorId, joinParams)
+
     val projectRelPostJoinRel = JoinUtils.createProjectRelPostJoinRel(
       needSwitchChildren,
       joinType,
-      inputStreamedOutput,
-      inputBuildOutput,
+      streamedPlan.output,
+      buildPlan.output,
       context,
       operatorId,
       crossRel,
@@ -145,18 +161,39 @@
       inputBuildOutput)
   }
 
+  def validateJoinTypeAndBuildSide(): ValidationResult = {
+    val result = joinType match {
+      case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok
+      case _ =>
+        ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
+    }
+
+    if (!result.isValid) {
+      return result
+    }
+
+    (joinType, buildSide) match {
+      case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
+        ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
+      case _ => ValidationResult.ok // continue
+    }
+  }
+
   override protected def doValidateInternal(): ValidationResult = {
-    if (!BackendsApiManager.getSettings.supportBroadcastNestedLoopJoinExec()) {
-      return ValidationResult.notOk("Broadcast Nested Loop join is not supported in this backend")
+    if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) {
+      return ValidationResult.notOk(
+        s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")
     }
+
     if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) {
       return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
     }
-    (joinType, buildSide) match {
-      case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
-        return ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
-      case _ => // continue
+
+    val validateResult = validateJoinTypeAndBuildSide()
+    if (!validateResult.isValid) {
+      return validateResult
     }
+
     val substraitContext = new SubstraitContext
 
     val crossRel = JoinUtils.createCrossRel(
@@ -168,6 +205,7 @@
       buildPlan.output,
       substraitContext,
       substraitContext.nextOperatorId(this.nodeName),
+      genJoinParameters(),
       validation = true
     )
     doNativeValidation(substraitContext, crossRel)
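For reference, the accept/reject matrix that validateJoinTypeAndBuildSide() encodes, restated as a standalone predicate. Illustrative only: the names are mine, and the config and Substrait checks that precede it in doValidateInternal are omitted.

    import org.apache.spark.sql.catalyst.plans._

    object BnljValidationSketch {
      // true = the base transformer would accept this (joinType, buildSide) pair
      def supported(joinType: JoinType, buildLeft: Boolean): Boolean = joinType match {
        case _: InnerLike => true       // either build side works
        case LeftOuter => !buildLeft    // LeftOuter + BuildLeft is rejected
        case RightOuter => buildLeft    // RightOuter + BuildRight is rejected
        case _ => false                 // everything else falls back
      }

      def main(args: Array[String]): Unit = {
        println(supported(Inner, buildLeft = true))      // true
        println(supported(LeftOuter, buildLeft = true))  // false
        println(supported(RightOuter, buildLeft = true)) // true
      }
    }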
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index eb2c0bfd7229..9dd73800e29b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -337,6 +337,7 @@ object JoinUtils {
       inputBuildOutput: Seq[Attribute],
       substraitContext: SubstraitContext,
       operatorId: java.lang.Long,
+      joinParameters: Any,
       validation: Boolean = false
   ): RelNode = {
     val expressionNode = condition.map {
       expr =>
         ExpressionConverter
           .replaceWithExpressionTransformer(expr, inputStreamedOutput ++ inputBuildOutput)
           .doTransform(substraitContext.registeredFunction)
     }
     val extensionNode =
-      JoinUtils.createExtensionNode(inputStreamedOutput ++ inputBuildOutput, validation)
+      createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput)
 
     RelBuilder.makeCrossRel(
       inputStreamedRelNode,
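The new joinParameters argument, together with genJoinParameters() defaulting to Any.getDefaultInstance, is the extension point for shipping per-join options to the native side. A hypothetical sketch of what an override could pack; the option string and StringValue payload are made up, any protobuf message would do:

    import com.google.protobuf.{Any, StringValue}

    object JoinParametersSketch {
      // The default used by the base transformer.
      def default: Any = Any.getDefaultInstance

      // Hypothetical backend override: pack an arbitrary message as the blob
      // that travels through the CrossRel extension node.
      def withOptions(options: String): Any =
        Any.pack(StringValue.newBuilder().setValue(options).build())

      def main(args: Array[String]): Unit =
        println(withOptions("isWithCondition=true").getTypeUrl)
    }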
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
index 15fc8bea7054..b7a30f7e177a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala
@@ -20,10 +20,14 @@ import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, Trans
 import org.apache.gluten.utils.{LogLevelUtil, PlanUtil}
 
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSemi}
 import org.apache.spark.sql.catalyst.rules.{PlanChangeLogger, Rule}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
 import org.apache.spark.sql.internal.SQLConf
 
 object MiscColumnarRules {
@@ -190,4 +194,22 @@ object MiscColumnarRules {
       child
     }
   }
+
+  // Removes an unnecessary BNLJ, e.g.:
+  // ``` select l.* from l left semi join r; ```
+  // The result is always just the left table.
+  case class RemoveBroadcastNestedLoopJoin() extends Rule[SparkPlan] {
+    override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+      case BroadcastNestedLoopJoinExec(
+            left: SparkPlan,
+            right: SparkPlan,
+            buildSide: BuildSide,
+            joinType: JoinType,
+            condition: Option[Expression]) if condition.isEmpty && joinType == LeftSemi =>
+        buildSide match {
+          case BuildLeft => right
+          case BuildRight => left
+        }
+    }
+  }
 }
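The rule fires only on the conditionless LeftSemi shape. A toy repro, runnable in a plain Spark shell (table names invented), that produces exactly that physical pattern; in a Gluten session with the rule installed, the join then collapses to the kept side:

    import org.apache.spark.sql.SparkSession

    object RemoveBnljExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        spark.range(3).toDF("a").createOrReplaceTempView("l")
        spark.range(2).toDF("b").createOrReplaceTempView("r")
        // No join condition, so Spark plans BroadcastNestedLoopJoin (LeftSemi, BuildRight);
        // RemoveBroadcastNestedLoopJoin would reduce this to a scan of `l`.
        spark.sql("select l.* from l left semi join r").explain()
      }
    }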
diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
index 959bf808aba4..b6236ae9a536 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala
@@ -147,8 +147,6 @@ object Validators {
       case p: SortAggregateExec if !settings.replaceSortAggWithHashAgg => fail(p)
       case p: CartesianProductExec if !settings.supportCartesianProductExec() => fail(p)
-      case p: BroadcastNestedLoopJoinExec if !settings.supportBroadcastNestedLoopJoinExec() =>
-        fail(p)
       case p: TakeOrderedAndProjectExec if !settings.supportColumnarShuffleExec() => fail(p)
       case _ => pass()
     }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
index 9671c7a6bca2..e8e7ce06feaf 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
@@ -48,6 +48,10 @@ object SubstraitUtil {
         // the left and right relations are exchanged and the
         // join type is reverted.
         CrossRel.JoinType.JOIN_TYPE_LEFT
+      case LeftSemi =>
+        CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI
+      case FullOuter =>
+        CrossRel.JoinType.JOIN_TYPE_OUTER
       case _ =>
         CrossRel.JoinType.UNRECOGNIZED
     }
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index df1c87cb0ccc..4da7a2f6f11a 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
@@ -134,13 +134,6 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
   }
 
   override protected def doValidateInternal(): ValidationResult = {
-    // CH backend does not support IdentityBroadcastMode used in BNLJ
-    if (
-      mode == IdentityBroadcastMode && !BackendsApiManager.getSettings
-        .supportBroadcastNestedLoopJoinExec()
-    ) {
-      return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ")
-    }
     BackendsApiManager.getValidatorApiInstance
       .doSchemaValidate(schema)
       .map {
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
index 6860d6a12958..e724cf31c689 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/gluten/GlutenFallbackSuite.scala
@@ -110,7 +110,7 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
         execution.get.fallbackNodeToReason.head._2
           .contains("FullOuter join is not supported with BroadcastNestedLoopJoin"))
     } else {
-      assert(execution.get.numFallbackNodes == 2)
+      assert(execution.get.numFallbackNodes == 0)
     }
   }
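Finally, since validation is now gated on a GlutenConfig flag rather than a backend capability check, a session (or test) that wants BNLJ offload has to switch the flag on. A minimal sketch, reading the key programmatically from the config entry referenced in the hunks above rather than hard-coding the string:

    import org.apache.gluten.GlutenConfig
    import org.apache.spark.sql.SparkSession

    object EnableBnljOffload {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        // Enable the BroadcastNestedLoopJoin transformer for this session.
        spark.conf.set(GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key, "true")
      }
    }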