diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index ac7cf67d8f306..c408a223784b6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -16,19 +16,17 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.extension.ValidationResult import org.apache.spark.rdd.RDD import org.apache.spark.rpc.GlutenDriverEndpoint import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide -import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch -import com.google.protobuf.{Any, StringValue} - case class CHBroadcastNestedLoopJoinExecTransformer( left: SparkPlan, right: SparkPlan, @@ -42,8 +40,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer( joinType, condition ) { - // Unique ID for builded table - lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { val streamedRDD = getColumnarInputRDDs(streamedPlan) @@ -81,17 +77,17 @@ case class CHBroadcastNestedLoopJoinExecTransformer( res } - override def genJoinParameters(): Any = { - val joinParametersStr = new StringBuffer("JoinParameters:") - joinParametersStr - .append("buildHashTableId=") - .append(buildBroadcastTableId) - .append("\n") - val message = StringValue - .newBuilder() - .setValue(joinParametersStr.toString) - .build() - BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + override def validateJoinTypeAndBuildSide(): ValidationResult = { + joinType match { + case _: InnerLike => + case _ => + if (condition.isDefined) { + return ValidationResult.notOk( + s"Broadcast Nested Loop join is not supported join type $joinType with conditions") + } + } + + ValidationResult.ok } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala index 59c2d6494bdba..f7d9a6dbe7eac 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/FallbackBroadcaseHashJoinRules.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec} -import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec} import scala.util.control.Breaks.{break, breakable} @@ -103,6 +103,10 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl GlutenConfig.getConf.enableColumnarBroadcastJoin && GlutenConfig.getConf.enableColumnarBroadcastExchange + private val enableColumnarBroadcastNestedLoopJoin: Boolean = + GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled && + GlutenConfig.getConf.enableColumnarBroadcastExchange + override def apply(plan: SparkPlan): SparkPlan = { plan.foreachUp { p => @@ -138,63 +142,9 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl case BuildRight => bhj.right } - val maybeExchange = buildSidePlan - .find { - case BroadcastExchangeExec(_, _) => true - case _ => false - } - .map(_.asInstanceOf[BroadcastExchangeExec]) - - maybeExchange match { - case Some(exchange @ BroadcastExchangeExec(mode, child)) => - isBhjTransformable.tagOnFallback(bhj) - if (!isBhjTransformable.isValid) { - FallbackTags.add(exchange, isBhjTransformable) - } - case None => - // we are in AQE, find the hidden exchange - // FIXME did we consider the case that AQE: OFF && Reuse: ON ? - var maybeHiddenExchange: Option[BroadcastExchangeLike] = None - breakable { - buildSidePlan.foreach { - case e: BroadcastExchangeLike => - maybeHiddenExchange = Some(e) - break - case t: BroadcastQueryStageExec => - t.plan.foreach { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case r: ReusedExchangeExec => - r.child match { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case _ => - } - case _ => - } - case _ => - } - } - // restriction to force the hidden exchange to be found - val exchange = maybeHiddenExchange.get - // to conform to the underlying exchange's type, columnar or vanilla - exchange match { - case BroadcastExchangeExec(mode, child) => - FallbackTags.add( - bhj, - "it's a materialized broadcast exchange or reused broadcast exchange") - case ColumnarBroadcastExchangeExec(mode, child) => - if (!isBhjTransformable.isValid) { - throw new IllegalStateException( - s"BroadcastExchange has already been" + - s" transformed to columnar version but BHJ is determined as" + - s" non-transformable: ${bhj.toString()}") - } - } - } + preTagBroadcastExchangeFallback(bhj, buildSidePlan, isBhjTransformable) } + case bnlj: BroadcastNestedLoopJoinExec => applyBNLJFallback(bnlj) case _ => } } catch { @@ -207,4 +157,88 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPl } plan } + + private def applyBNLJFallback(bnlj: BroadcastNestedLoopJoinExec) = { + if (!enableColumnarBroadcastNestedLoopJoin) { + FallbackTags.add(bnlj, "columnar BroadcastJoin is not enabled in BroadcastNestedLoopJoinExec") + } + + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastNestedLoopJoinExecTransformer( + bnlj.left, + bnlj.right, + bnlj.buildSide, + bnlj.joinType, + bnlj.condition) + + val isBNLJTransformable = transformer.doValidate() + val buildSidePlan = bnlj.buildSide match { + case BuildLeft => bnlj.left + case BuildRight => bnlj.right + } + + preTagBroadcastExchangeFallback(bnlj, buildSidePlan, isBNLJTransformable) + } + + private def preTagBroadcastExchangeFallback( + plan: SparkPlan, + buildSidePlan: SparkPlan, + isTransformable: ValidationResult): Unit = { + val maybeExchange = buildSidePlan + .find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + .map(_.asInstanceOf[BroadcastExchangeExec]) + + maybeExchange match { + case Some(exchange @ BroadcastExchangeExec(_, _)) => + isTransformable.tagOnFallback(plan) + if (!isTransformable.isValid) { + FallbackTags.add(exchange, isTransformable) + } + case None => + // we are in AQE, find the hidden exchange + // FIXME did we consider the case that AQE: OFF && Reuse: ON ? + var maybeHiddenExchange: Option[BroadcastExchangeLike] = None + breakable { + buildSidePlan.foreach { + case e: BroadcastExchangeLike => + maybeHiddenExchange = Some(e) + break + case t: BroadcastQueryStageExec => + t.plan.foreach { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case r: ReusedExchangeExec => + r.child match { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case _ => + } + case _ => + } + case _ => + } + } + // restriction to force the hidden exchange to be found + val exchange = maybeHiddenExchange.get + // to conform to the underlying exchange's type, columnar or vanilla + exchange match { + case BroadcastExchangeExec(mode, child) => + FallbackTags.add( + plan, + "it's a materialized broadcast exchange or reused broadcast exchange") + case ColumnarBroadcastExchangeExec(mode, child) => + if (!isTransformable.isValid) { + throw new IllegalStateException( + s"BroadcastExchange has already been" + + s" transformed to columnar version but BHJ is determined as" + + s" non-transformable: ${plan.toString()}") + } + } + } + } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 267e45920b8f9..58b6b7d4c2d28 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.metrics.MetricsUpdater @@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec import org.apache.spark.sql.execution.metric.SQLMetric -import com.google.protobuf.Any +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.CrossRel abstract class BroadcastNestedLoopJoinExecTransformer( @@ -50,7 +51,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer( private lazy val substraitJoinType: CrossRel.JoinType = SubstraitUtil.toCrossRelSubstrait(joinType) - private lazy val buildTableId: String = "BuildTable-" + buildPlan.id + // Unique ID for builded table + lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id // Hint substrait to switch the left and right, // since we assume always build right side in substrait. @@ -108,7 +110,19 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } } - def genJoinParameters(): Any = Any.getDefaultInstance + def genJoinParameters(): Any = { + // for ch + val joinParametersStr = new StringBuffer("JoinParameters:") + joinParametersStr + .append("buildHashTableId=") + .append(buildBroadcastTableId) + .append("\n") + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } override protected def doTransform(context: SubstraitContext): TransformContext = { val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context) @@ -159,19 +173,35 @@ abstract class BroadcastNestedLoopJoinExecTransformer( inputBuildOutput) } + def validateJoinTypeAndBuildSide(): ValidationResult = { + joinType match { + case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok + case _ => + ValidationResult.notOk( + s"Broadcast Nested Loop join is not supported join type $joinType in this backend") + } + + (joinType, buildSide) match { + case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => + ValidationResult.notOk(s"$joinType join is not supported with $buildSide") + case _ => ValidationResult.ok // continue + } + } + override protected def doValidateInternal(): ValidationResult = { - if (!BackendsApiManager.getSettings.supportBroadcastNestedJoinJoinType(joinType)) { + if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled) return ValidationResult.notOk( - s"Broadcast Nested Loop join is not supported join type $joinType in this backend") - } + s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled") + if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") } - (joinType, buildSide) match { - case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) => - return ValidationResult.notOk(s"$joinType join is not supported with $buildSide") - case _ => // continue + + val validateResult = validateJoinTypeAndBuildSide() + if (!validateResult.isValid) { + return validateResult } + val substraitContext = new SubstraitContext val crossRel = JoinUtils.createCrossRel( diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index 3d7a075bfb26f..4da7a2f6f11ae 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -134,14 +134,6 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } override protected def doValidateInternal(): ValidationResult = { -// // CH backend does not support IdentityBroadcastMode used in BNLJ -// if ( -// mode == IdentityBroadcastMode && !BackendsApiManager.getSettings -// .supportBroadcastNestedLoopJoinExec() -// ) { -// return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ") -// } - BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map {