all fallback
loneylee committed Jul 8, 2024
1 parent 81da238 commit 84a10c2
Showing 4 changed files with 144 additions and 92 deletions.
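
For orientation, here is a compact standalone Scala sketch of the pattern these diffs introduce: the base BroadcastNestedLoopJoinExecTransformer owns a default validateJoinTypeAndBuildSide hook, a backend override (ClickHouse, in the first file) narrows it, and the result decides whether the join stays columnar or is tagged for fallback. Everything in the sketch is a simplified stand-in written for illustration — JoinType, ValidationResult, BnljValidation, and ChLikeBnlj are not the Gluten classes shown below.

// Standalone sketch, not Gluten source: simplified stand-ins for the real types.
sealed trait JoinType
case object Inner extends JoinType
case object FullOuter extends JoinType

final case class ValidationResult(isValid: Boolean, reason: Option[String])
object ValidationResult {
  val ok: ValidationResult = ValidationResult(isValid = true, reason = None)
  def notOk(reason: String): ValidationResult =
    ValidationResult(isValid = false, reason = Some(reason))
}

// The base transformer supplies a default join-type check that subclasses may override.
trait BnljValidation {
  def joinType: JoinType
  def hasCondition: Boolean

  // Default rule: inner-like joins pass, everything else is rejected.
  def validateJoinTypeAndBuildSide(): ValidationResult = joinType match {
    case Inner => ValidationResult.ok
    case other => ValidationResult.notOk(s"join type $other is not supported")
  }
}

// A CH-like override: non-inner joins are rejected only when they carry a join condition.
final case class ChLikeBnlj(joinType: JoinType, hasCondition: Boolean) extends BnljValidation {
  override def validateJoinTypeAndBuildSide(): ValidationResult = joinType match {
    case Inner => ValidationResult.ok
    case other if hasCondition =>
      ValidationResult.notOk(s"join type $other with a condition is not supported")
    case _ => ValidationResult.ok
  }
}

object ValidationDemo extends App {
  println(ChLikeBnlj(FullOuter, hasCondition = true).validateJoinTypeAndBuildSide())  // notOk: tagged for fallback
  println(ChLikeBnlj(FullOuter, hasCondition = false).validateJoinTypeAndBuildSide()) // ok: stays columnar
}

Keeping a permissive default in the base class and narrowing it per backend, as the diffs below do, concentrates the fallback decision in one overridable hook instead of scattering join-type checks through doValidateInternal.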
@@ -16,19 +16,17 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult

import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}

case class CHBroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
right: SparkPlan,
@@ -42,8 +40,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
joinType,
condition
) {
// Unique ID for the built table
lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id

override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
@@ -81,17 +77,17 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
res
}

override def genJoinParameters(): Any = {
val joinParametersStr = new StringBuffer("JoinParameters:")
joinParametersStr
.append("buildHashTableId=")
.append(buildBroadcastTableId)
.append("\n")
val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
override def validateJoinTypeAndBuildSide(): ValidationResult = {
joinType match {
case _: InnerLike =>
case _ =>
if (condition.isDefined) {
return ValidationResult.notOk(
s"Broadcast Nested Loop join is not supported join type $joinType with conditions")
}
}

ValidationResult.ok
}

}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}

import scala.util.control.Breaks.{break, breakable}

@@ -103,6 +103,10 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
GlutenConfig.getConf.enableColumnarBroadcastJoin &&
GlutenConfig.getConf.enableColumnarBroadcastExchange

private val enableColumnarBroadcastNestedLoopJoin: Boolean =
GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
GlutenConfig.getConf.enableColumnarBroadcastExchange

override def apply(plan: SparkPlan): SparkPlan = {
plan.foreachUp {
p =>
@@ -138,63 +142,9 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
case BuildRight => bhj.right
}

val maybeExchange = buildSidePlan
.find {
case BroadcastExchangeExec(_, _) => true
case _ => false
}
.map(_.asInstanceOf[BroadcastExchangeExec])

maybeExchange match {
case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
isBhjTransformable.tagOnFallback(bhj)
if (!isBhjTransformable.isValid) {
FallbackTags.add(exchange, isBhjTransformable)
}
case None =>
// we are in AQE, find the hidden exchange
// FIXME did we consider the case that AQE: OFF && Reuse: ON ?
var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
breakable {
buildSidePlan.foreach {
case e: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e)
break
case t: BroadcastQueryStageExec =>
t.plan.foreach {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case r: ReusedExchangeExec =>
r.child match {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case _ =>
}
case _ =>
}
case _ =>
}
}
// restriction to force the hidden exchange to be found
val exchange = maybeHiddenExchange.get
// to conform to the underlying exchange's type, columnar or vanilla
exchange match {
case BroadcastExchangeExec(mode, child) =>
FallbackTags.add(
bhj,
"it's a materialized broadcast exchange or reused broadcast exchange")
case ColumnarBroadcastExchangeExec(mode, child) =>
if (!isBhjTransformable.isValid) {
throw new IllegalStateException(
s"BroadcastExchange has already been" +
s" transformed to columnar version but BHJ is determined as" +
s" non-transformable: ${bhj.toString()}")
}
}
}
preTagBroadcastExchangeFallback(bhj, buildSidePlan, isBhjTransformable)
}
case bnlj: BroadcastNestedLoopJoinExec => applyBNLJFallback(bnlj)
case _ =>
}
} catch {
@@ -207,4 +157,88 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
}
plan
}

private def applyBNLJFallback(bnlj: BroadcastNestedLoopJoinExec) = {
if (!enableColumnarBroadcastNestedLoopJoin) {
FallbackTags.add(bnlj, "columnar BroadcastJoin is not enabled in BroadcastNestedLoopJoinExec")
}

val transformer = BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
bnlj.left,
bnlj.right,
bnlj.buildSide,
bnlj.joinType,
bnlj.condition)

val isBNLJTransformable = transformer.doValidate()
val buildSidePlan = bnlj.buildSide match {
case BuildLeft => bnlj.left
case BuildRight => bnlj.right
}

preTagBroadcastExchangeFallback(bnlj, buildSidePlan, isBNLJTransformable)
}

private def preTagBroadcastExchangeFallback(
plan: SparkPlan,
buildSidePlan: SparkPlan,
isTransformable: ValidationResult): Unit = {
val maybeExchange = buildSidePlan
.find {
case BroadcastExchangeExec(_, _) => true
case _ => false
}
.map(_.asInstanceOf[BroadcastExchangeExec])

maybeExchange match {
case Some(exchange @ BroadcastExchangeExec(_, _)) =>
isTransformable.tagOnFallback(plan)
if (!isTransformable.isValid) {
FallbackTags.add(exchange, isTransformable)
}
case None =>
// we are in AQE, find the hidden exchange
// FIXME did we consider the case that AQE: OFF && Reuse: ON ?
var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
breakable {
buildSidePlan.foreach {
case e: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e)
break
case t: BroadcastQueryStageExec =>
t.plan.foreach {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case r: ReusedExchangeExec =>
r.child match {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case _ =>
}
case _ =>
}
case _ =>
}
}
// restriction to force the hidden exchange to be found
val exchange = maybeHiddenExchange.get
// to conform to the underlying exchange's type, columnar or vanilla
exchange match {
case BroadcastExchangeExec(mode, child) =>
FallbackTags.add(
plan,
"it's a materialized broadcast exchange or reused broadcast exchange")
case ColumnarBroadcastExchangeExec(mode, child) =>
if (!isTransformable.isValid) {
throw new IllegalStateException(
s"BroadcastExchange has already been" +
s" transformed to columnar version but BHJ is determined as" +
s" non-transformable: ${plan.toString()}")
}
}
}
}
}
@@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.metrics.MetricsUpdater
@@ -30,7 +31,7 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.BaseJoinExec
import org.apache.spark.sql.execution.metric.SQLMetric

import com.google.protobuf.Any
import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.CrossRel

abstract class BroadcastNestedLoopJoinExecTransformer(
@@ -50,7 +51,8 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
private lazy val substraitJoinType: CrossRel.JoinType =
SubstraitUtil.toCrossRelSubstrait(joinType)

private lazy val buildTableId: String = "BuildTable-" + buildPlan.id
// Unique ID for the built table
lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id

// Hint substrait to switch the left and right,
// since we assume always build right side in substrait.
@@ -108,7 +110,19 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
}
}

def genJoinParameters(): Any = Any.getDefaultInstance
def genJoinParameters(): Any = {
// Needed by the CH (ClickHouse) backend
val joinParametersStr = new StringBuffer("JoinParameters:")
joinParametersStr
.append("buildHashTableId=")
.append(buildBroadcastTableId)
.append("\n")
val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

override protected def doTransform(context: SubstraitContext): TransformContext = {
val streamedPlanContext = streamedPlan.asInstanceOf[TransformSupport].transform(context)
@@ -159,19 +173,35 @@ abstract class BroadcastNestedLoopJoinExecTransformer(
inputBuildOutput)
}

def validateJoinTypeAndBuildSide(): ValidationResult = {
joinType match {
case _: InnerLike | LeftOuter | RightOuter => ValidationResult.ok
case _ =>
ValidationResult.notOk(
s"Broadcast Nested Loop join is not supported join type $joinType in this backend")
}

(joinType, buildSide) match {
case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
case _ => ValidationResult.ok // continue
}
}

override protected def doValidateInternal(): ValidationResult = {
if (!BackendsApiManager.getSettings.supportBroadcastNestedJoinJoinType(joinType)) {
if (!GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled)
return ValidationResult.notOk(
s"Broadcast Nested Loop join is not supported join type $joinType in this backend")
}
s"Config ${GlutenConfig.BROADCAST_NESTED_LOOP_JOIN_TRANSFORMER_ENABLED.key} not enabled")

if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) {
return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin")
}
(joinType, buildSide) match {
case (LeftOuter, BuildLeft) | (RightOuter, BuildRight) =>
return ValidationResult.notOk(s"$joinType join is not supported with $buildSide")
case _ => // continue

val validateResult = validateJoinTypeAndBuildSide()
if (!validateResult.isValid) {
return validateResult
}

val substraitContext = new SubstraitContext

val crossRel = JoinUtils.createCrossRel(
@@ -134,14 +134,6 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}

override protected def doValidateInternal(): ValidationResult = {
// // CH backend does not support IdentityBroadcastMode used in BNLJ
// if (
// mode == IdentityBroadcastMode && !BackendsApiManager.getSettings
// .supportBroadcastNestedLoopJoinExec()
// ) {
// return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ")
// }

BackendsApiManager.getValidatorApiInstance
.doSchemaValidate(schema)
.map {

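To close, a small self-contained illustration of the join-parameters packing performed by the relocated genJoinParameters: the newline-delimited parameter string is wrapped in a protobuf StringValue and packed into an Any that travels with the Substrait plan. The plan id is invented for the demo, and treating packPBMessage as roughly equivalent to Any.pack is an assumption; only com.google.protobuf is required.

import com.google.protobuf.{Any, StringValue}

object JoinParametersPackingDemo extends App {
  // Same newline-delimited key=value payload the transformer assembles.
  val params = new StringBuffer("JoinParameters:")
    .append("buildHashTableId=")
    .append("BuiltBNLJBroadcastTable-42") // hypothetical buildPlan.id, for the demo only
    .append("\n")
    .toString

  val message: StringValue = StringValue.newBuilder().setValue(params).build()

  // Pack the StringValue into a protobuf Any; packPBMessage presumably does something similar.
  val packed: Any = Any.pack(message)

  println(packed.getTypeUrl)                            // type.googleapis.com/google.protobuf.StringValue
  println(packed.unpack(classOf[StringValue]).getValue) // JoinParameters:buildHashTableId=...
}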