Commit 330878f

fix metric
fix ci

fix velox ci

tag

fix rebase

all fallback

fix style
loneylee committed Jul 8, 2024
1 parent ebb67fa commit 330878f
Showing 24 changed files with 273 additions and 515 deletions.
@@ -74,12 +74,20 @@ public static long build(
return converter.genColumnNameWithExprId(attr);
})
.collect(Collectors.joining(","));

int joinType;
if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) {
joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal();
} else {
joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal();
}

return nativeBuild(
broadCastContext.buildHashTableId(),
batches,
rowCount,
joinKey,
SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(),
joinType,
broadCastContext.hasMixedFiltCondition(),
toNameStruct(output).toByteArray());
}
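
A reading note on the hunk above, as a sketch: the build-table id doubles as a tag for which Substrait conversion applies. The block below restates the diff in Scala; the helper name substraitJoinOrdinal is illustrative, while SubstraitUtil and the prefix come from the change itself.

  // BNLJ broadcast tables carry the "BuiltBNLJBroadcastTable-" prefix, so the
  // id alone decides whether the join type maps to a Substrait cross rel
  // (broadcast nested loop join) or an equi-join rel (broadcast hash join).
  def substraitJoinOrdinal(buildHashTableId: String, joinType: JoinType): Int =
    if (buildHashTableId.startsWith("BuiltBNLJBroadcastTable-")) {
      SubstraitUtil.toCrossRelSubstrait(joinType).ordinal()
    } else {
      SubstraitUtil.toSubstrait(joinType).ordinal()
    }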
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -297,4 +298,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
}

override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true

override def supportBroadcastNestedJoinJoinType: JoinType => Boolean = {
case _: InnerLike | LeftSemi | FullOuter => true
case _ => false
}
}
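
A hedged usage sketch for the new predicate (the checkBnljJoinType wrapper is illustrative; the supported set of InnerLike, LeftSemi and FullOuter comes from the code above):

  import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftOuter}

  // Illustrative pre-offload check: None means offloadable to the CH backend,
  // Some(reason) means the join should fall back to vanilla Spark.
  def checkBnljJoinType(joinType: JoinType): Option[String] =
    if (CHBackendSettings.supportBroadcastNestedJoinJoinType(joinType)) None
    else Some(s"join type $joinType not supported by CH broadcast nested loop join")

  checkBnljJoinType(Inner)     // None: InnerLike is supported
  checkBnljJoinType(FullOuter) // None: FullOuter is supported
  checkBnljJoinType(LeftOuter) // Some(...): not in the supported set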
@@ -357,10 +357,6 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
"extraTime" -> SQLMetrics.createTimingMetric(sparkContext, "extra operators time"),
"inputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for data"),
"outputWaitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time of waiting for output"),
"streamPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of stream side preProjection"),
"buildPreProjectionTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of build side preProjection"),
"postProjectTime" ->
SQLMetrics.createTimingMetric(sparkContext, "time of postProjection"),
"probeTime" ->
@@ -16,20 +16,17 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult

import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashJoin}
import org.apache.spark.sql.types._
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}

case class CHBroadcastNestedLoopJoinExecTransformer(
left: SparkPlan,
right: SparkPlan,
@@ -43,32 +40,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
joinType,
condition
) {
// Unique ID for the built broadcast table
lazy val buildBroadcastTableId: String = "BuiltBroadcastTable-" + buildPlan.id

lazy val (buildKeyExprs, streamedKeyExprs) = {
require(
leftKeys.length == rightKeys.length &&
leftKeys
.map(_.dataType)
.zip(rightKeys.map(_.dataType))
.forall(types => sameType(types._1, types._2)),
"Join keys from two sides should have same length and types"
)
// Spark has an optimization that rewrites integer join keys into a single Long value.
// In Velox this rewrite adds an extra projection before the hash join, so it is
// disabled here to avoid that projection.
val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) {
(HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys))
} else {
(leftKeys, rightKeys)
}
if (needSwitchChildren) {
(lkeys, rkeys)
} else {
(rkeys, lkeys)
}
}

override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = {
val streamedRDD = getColumnarInputRDDs(streamedPlan)
@@ -106,38 +77,17 @@ case class CHBroadcastNestedLoopJoinExecTransformer(
res
}

def sameType(from: DataType, to: DataType): Boolean = {
  (from, to) match {
    case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
      sameType(fromElement, toElement)

    case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
      sameType(fromKey, toKey) &&
        sameType(fromValue, toValue)

    case (StructType(fromFields), StructType(toFields)) =>
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall {
          case (l, r) =>
            l.name.equalsIgnoreCase(r.name) &&
              sameType(l.dataType, r.dataType)
        }

    case (fromDataType, toDataType) => fromDataType == toDataType
  }
}

override def genJoinParameters(): Any = {
  val joinParametersStr = new StringBuffer("JoinParameters:")
  joinParametersStr
    .append("buildHashTableId=")
    .append(buildBroadcastTableId)
    .append("\n")
  val message = StringValue
    .newBuilder()
    .setValue(joinParametersStr.toString)
    .build()
  BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

override def validateJoinTypeAndBuildSide(): ValidationResult = {
  joinType match {
    case _: InnerLike =>
    case _ =>
      if (condition.isDefined) {
        return ValidationResult.notOk(
          s"Broadcast Nested Loop join is not supported join type $joinType with conditions")
      }
  }
  ValidationResult.ok
}

}
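
How the new hook is meant to be consumed, as a sketch: validateJoinTypeAndBuildSide overrides a hook in the shared transformer base, and its result presumably feeds the transformer's doValidate during planning. The transformer value below is assumed; ValidationResult.ok/notOk and isValid are from the diff.

  // Outcomes of the match above:
  //   InnerLike, with or without condition -> ValidationResult.ok
  //   any other join type, no condition    -> ValidationResult.ok
  //   any other join type, with condition  -> ValidationResult.notOk(...)
  val result = transformer.validateJoinTypeAndBuildSide()
  if (!result.isValid) {
    // the reason string surfaces as a fallback tag on the vanilla Spark join
  }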
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec}

import scala.util.control.Breaks.{break, breakable}

@@ -103,6 +103,10 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
GlutenConfig.getConf.enableColumnarBroadcastJoin &&
GlutenConfig.getConf.enableColumnarBroadcastExchange

private val enableColumnarBroadcastNestedLoopJoin: Boolean =
GlutenConfig.getConf.broadcastNestedLoopJoinTransformerTransformerEnabled &&
GlutenConfig.getConf.enableColumnarBroadcastExchange

override def apply(plan: SparkPlan): SparkPlan = {
plan.foreachUp {
p =>
@@ -138,63 +142,9 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
case BuildRight => bhj.right
}

val maybeExchange = buildSidePlan
.find {
case BroadcastExchangeExec(_, _) => true
case _ => false
}
.map(_.asInstanceOf[BroadcastExchangeExec])

maybeExchange match {
case Some(exchange @ BroadcastExchangeExec(mode, child)) =>
isBhjTransformable.tagOnFallback(bhj)
if (!isBhjTransformable.isValid) {
FallbackTags.add(exchange, isBhjTransformable)
}
case None =>
// we are in AQE, find the hidden exchange
// FIXME did we consider the case that AQE: OFF && Reuse: ON ?
var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
breakable {
buildSidePlan.foreach {
case e: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e)
break
case t: BroadcastQueryStageExec =>
t.plan.foreach {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case r: ReusedExchangeExec =>
r.child match {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case _ =>
}
case _ =>
}
case _ =>
}
}
// restriction to force the hidden exchange to be found
val exchange = maybeHiddenExchange.get
// to conform to the underlying exchange's type, columnar or vanilla
exchange match {
case BroadcastExchangeExec(mode, child) =>
FallbackTags.add(
bhj,
"it's a materialized broadcast exchange or reused broadcast exchange")
case ColumnarBroadcastExchangeExec(mode, child) =>
if (!isBhjTransformable.isValid) {
throw new IllegalStateException(
s"BroadcastExchange has already been" +
s" transformed to columnar version but BHJ is determined as" +
s" non-transformable: ${bhj.toString()}")
}
}
}
preTagBroadcastExchangeFallback(bhj, buildSidePlan, isBhjTransformable)
}
case bnlj: BroadcastNestedLoopJoinExec => applyBNLJFallback(bnlj)
case _ =>
}
} catch {
@@ -207,4 +157,88 @@ case class FallbackBroadcastHashJoin(session: SparkSession) extends Rule[SparkPlan] {
}
plan
}

private def applyBNLJFallback(bnlj: BroadcastNestedLoopJoinExec) = {
if (!enableColumnarBroadcastNestedLoopJoin) {
FallbackTags.add(bnlj, "columnar BroadcastJoin is not enabled in BroadcastNestedLoopJoinExec")
}

val transformer = BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
bnlj.left,
bnlj.right,
bnlj.buildSide,
bnlj.joinType,
bnlj.condition)

val isBNLJTransformable = transformer.doValidate()
val buildSidePlan = bnlj.buildSide match {
case BuildLeft => bnlj.left
case BuildRight => bnlj.right
}

preTagBroadcastExchangeFallback(bnlj, buildSidePlan, isBNLJTransformable)
}

private def preTagBroadcastExchangeFallback(
plan: SparkPlan,
buildSidePlan: SparkPlan,
isTransformable: ValidationResult): Unit = {
val maybeExchange = buildSidePlan
.find {
case BroadcastExchangeExec(_, _) => true
case _ => false
}
.map(_.asInstanceOf[BroadcastExchangeExec])

maybeExchange match {
case Some(exchange @ BroadcastExchangeExec(_, _)) =>
isTransformable.tagOnFallback(plan)
if (!isTransformable.isValid) {
FallbackTags.add(exchange, isTransformable)
}
case None =>
// we are in AQE, find the hidden exchange
// FIXME did we consider the case that AQE: OFF && Reuse: ON ?
var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
breakable {
buildSidePlan.foreach {
case e: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e)
break
case t: BroadcastQueryStageExec =>
t.plan.foreach {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case r: ReusedExchangeExec =>
r.child match {
case e2: BroadcastExchangeLike =>
maybeHiddenExchange = Some(e2)
break
case _ =>
}
case _ =>
}
case _ =>
}
}
// restriction to force the hidden exchange to be found
val exchange = maybeHiddenExchange.get
// to conform to the underlying exchange's type, columnar or vanilla
exchange match {
case BroadcastExchangeExec(mode, child) =>
FallbackTags.add(
plan,
"it's a materialized broadcast exchange or reused broadcast exchange")
case ColumnarBroadcastExchangeExec(mode, child) =>
if (!isTransformable.isValid) {
throw new IllegalStateException(
s"BroadcastExchange has already been" +
s" transformed to columnar version but BHJ is determined as" +
s" non-transformable: ${plan.toString()}")
}
}
}
}
}
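
The breakable search in preTagBroadcastExchangeFallback handles three build-side shapes: a direct exchange, an AQE BroadcastQueryStageExec wrapping one, and a query stage wrapping a ReusedExchangeExec. A condensed functional sketch of the same lookup (illustrative only; the rule itself keeps the imperative loop and traverses the whole stage subtree):

  def findHiddenExchange(buildSidePlan: SparkPlan): Option[BroadcastExchangeLike] = {
    // Unwrap the stage-local shapes; the real rule searches recursively.
    def unwrap(p: SparkPlan): Option[BroadcastExchangeLike] = p match {
      case e: BroadcastExchangeLike => Some(e)
      case r: ReusedExchangeExec =>
        r.child match {
          case e: BroadcastExchangeLike => Some(e)
          case _ => None
        }
      case _ => None
    }
    buildSidePlan.collectFirst {
      case e: BroadcastExchangeLike => e
      case stage: BroadcastQueryStageExec if unwrap(stage.plan).isDefined =>
        unwrap(stage.plan).get
    }
  }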
@@ -32,24 +32,6 @@ class BroadcastNestedLoopJoinMetricsUpdater(val metrics: Map[String, SQLMetric])
var currentIdx = operatorMetrics.metricsList.size() - 1
var totalTime = 0L

// build side pre projection
if (joinParams.buildPreProjectionNeeded) {
metrics("buildPreProjectionTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time
currentIdx -= 1
}

// stream side pre projection
if (joinParams.streamPreProjectionNeeded) {
metrics("streamPreProjectionTime") +=
(operatorMetrics.metricsList.get(currentIdx).time / 1000L).toLong
metrics("outputVectors") += operatorMetrics.metricsList.get(currentIdx).outputVectors
totalTime += operatorMetrics.metricsList.get(currentIdx).time
currentIdx -= 1
}

// update fillingRightJoinSideTime
MetricsUtil
.getAllProcessorList(operatorMetrics.metricsList.get(currentIdx))
@@ -76,6 +58,8 @@
}
}
if (processor.name.equalsIgnoreCase("FilterTransform")) {
metrics("conditionTime") += (processor.time / 1000L).toLong
metrics("numOutputRows") += processor.outputRows - processor.inputRows
metrics("outputBytes") += processor.outputBytes - processor.inputBytes
}
if (processor.name.equalsIgnoreCase("JoiningTransform")) {
metrics("probeTime") += (processor.time / 1000L).toLong
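
Why the two added FilterTransform lines are deltas rather than absolutes: the joining step has presumably already credited its full output to numOutputRows and outputBytes, so the post-join condition contributes outputRows - inputRows (usually negative), netting the metric down to the surviving rows. A worked sketch with assumed numbers:

  var numOutputRows = 0L
  numOutputRows += 1000L         // joining step credits its 1000 output rows (assumed)
  numOutputRows += 400L - 1000L  // FilterTransform delta: outputRows - inputRows
  assert(numOutputRows == 400L)  // metric ends at the rows surviving the condition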
@@ -56,14 +56,11 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
Seq("q" + "%d".format(queryNum))
}
val noFallBack = queryNum match {
case i
if i == 10 || i == 16 || i == 35 || i == 45 || i == 77 ||
i == 94 =>
case i if i == 10 || i == 16 || i == 35 || i == 45 || i == 94 =>
// Q10 BroadcastHashJoin, ExistenceJoin
// Q16 ShuffledHashJoin, NOT condition
// Q35 BroadcastHashJoin, ExistenceJoin
// Q45 BroadcastHashJoin, ExistenceJoin
// Q77 CartesianProduct
// Q94 BroadcastHashJoin, LeftSemi, NOT condition
(false, false)
case j if j == 38 || j == 87 =>
Expand All @@ -73,6 +70,9 @@ abstract class GlutenClickHouseTPCDSAbstractSuite
} else {
(false, true)
}
case q77 if q77 == 77 && !isAqe =>
// Q77 CartesianProduct
(false, false)
case other => (true, false)
}
sqlNums.map((_, noFallBack._1, noFallBack._2))