Skip to content

Commit

Permalink
[CORE] Simplify WholeStageTransformer and BroadcastBuildSideRDD (#3574)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you authored Nov 1, 2023
1 parent f37e9af commit 0c59a5d
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils.{LogLevelUtil, SubstraitPlanPrinterUtil}
import io.glutenproject.vectorized.{CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, GeneralInIterator, GeneralOutIterator}

import org.apache.spark.{InterruptibleIterator, Partition, SparkConf, SparkContext, TaskContext}
import org.apache.spark.{InterruptibleIterator, SparkConf, SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -267,8 +267,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {

/** Compute for BroadcastBuildSideRDD */
override def genBroadcastBuildSideIterator(
split: Partition,
context: TaskContext,
broadcasted: Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch] = {
CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadCastContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
import io.glutenproject.utils.Iterators
import io.glutenproject.vectorized._

import org.apache.spark.{Partition, SparkConf, SparkContext, TaskContext}
import org.apache.spark.{SparkConf, SparkContext, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -217,8 +217,6 @@ class IteratorApiImpl extends IteratorApi with Logging {

/** Compute for BroadcastBuildSideRDD */
override def genBroadcastBuildSideIterator(
split: Partition,
context: TaskContext,
broadcasted: Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch] = {
val relation = broadcasted.value.asReadOnlyCopy(broadCastContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ trait IteratorApi {

/** Compute for BroadcastBuildSideRDD */
def genBroadcastBuildSideIterator(
split: Partition,
context: TaskContext,
broadcasted: broadcast.Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext): Iterator[ColumnarBatch]
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,22 @@
*/
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.exception.GlutenException

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

/** Partition type for [[BroadcastBuildSideRDD]]; it carries nothing but its `index`. */
final private case class BroadcastBuildSideRDDPartition(index: Int) extends Partition

case class BroadcastBuildSideRDD(
@transient private val sc: SparkContext,
broadcasted: broadcast.Broadcast[BuildSideRelation],
broadCastContext: BroadCastHashJoinContext,
numPartitions: Int = -1)
broadCastContext: BroadCastHashJoinContext)
extends RDD[ColumnarBatch](sc, Nil) {

override def getPartitions: Array[Partition] = {
if (numPartitions < 0) {
throw new GlutenException(s"Invalid number of partitions: $numPartitions.")
}
Array.tabulate(numPartitions)(i => BroadcastBuildSideRDDPartition(i))
throw new IllegalStateException("Never reach here")
}

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
BackendsApiManager.getIteratorApiInstance
.genBroadcastBuildSideIterator(split, context, broadcasted, broadCastContext)
throw new IllegalStateException("Never reach here")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.metrics.IMetrics
import io.glutenproject.substrait.plan.PlanBuilder

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.PartitionedFile
Expand All @@ -32,8 +32,6 @@ import org.apache.spark.util.ExecutorManager

import io.substrait.proto.Plan

import java.io.Serializable

import scala.collection.mutable

trait BaseGlutenPartition extends Partition with InputPartition {
Expand Down Expand Up @@ -88,82 +86,54 @@ case class GlutenMergeTreePartition(
}

case class FirstZippedPartitionsPartition(
idx: Int,
index: Int,
inputPartition: InputPartition,
@transient private val rdds: Seq[RDD[_]] = Seq())
inputColumnarRDDPartitions: Seq[Partition] = Seq.empty)
extends Partition
with Serializable {

override val index: Int = idx
var partitionValues = rdds.map(rdd => rdd.partitions(idx))

def partitions: Seq[Partition] = partitionValues
}

class GlutenWholeStageColumnarRDD(
@transient sc: SparkContext,
@transient private val inputPartitions: Seq[InputPartition],
var rdds: Seq[RDD[ColumnarBatch]],
var rdds: ColumnarInputRDDsWrapper,
pipelineTime: SQLMetric,
updateInputMetrics: (InputMetricsWrapper) => Unit,
updateNativeMetrics: IMetrics => Unit)
extends RDD[ColumnarBatch](sc, rdds.map(x => new OneToOneDependency(x))) {
extends RDD[ColumnarBatch](sc, rdds.getDependencies) {
val numaBindingInfo = GlutenConfig.getConf.numaBindingInfo

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
ExecutorManager.tryTaskSet(numaBindingInfo)

val inputPartition = castNativePartition(split)
if (rdds.isEmpty) {
BackendsApiManager.getIteratorApiInstance.genFirstStageIterator(
inputPartition,
context,
pipelineTime,
updateInputMetrics,
updateNativeMetrics)
} else {
val partitions = split.asInstanceOf[FirstZippedPartitionsPartition].partitions
val inputIterators =
(rdds.zip(partitions)).map { case (rdd, partition) => rdd.iterator(partition, context) }
BackendsApiManager.getIteratorApiInstance.genFirstStageIterator(
inputPartition,
context,
pipelineTime,
updateInputMetrics,
updateNativeMetrics,
inputIterators
)
}
val (inputPartition, inputColumnarRDDPartitions) = castNativePartition(split)
val inputIterators = rdds.getIterators(inputColumnarRDDPartitions, context)
BackendsApiManager.getIteratorApiInstance.genFirstStageIterator(
inputPartition,
context,
pipelineTime,
updateInputMetrics,
updateNativeMetrics,
inputIterators
)
}

/**
 * Unwraps the native partition carried by a [[FirstZippedPartitionsPartition]].
 *
 * @param split the partition handed to this RDD
 * @return the wrapped [[BaseGlutenPartition]]
 * @throws SparkException if `split` is not a [[FirstZippedPartitionsPartition]] whose input
 *   partition is a [[BaseGlutenPartition]]
 */
private def castNativePartition(split: Partition): BaseGlutenPartition = {
  split match {
    case FirstZippedPartitionsPartition(_, nativePartition: BaseGlutenPartition, _) =>
      nativePartition
    case _ =>
      throw new SparkException(s"[BUG] Not a NativeSubstraitPartition: $split")
  }
}
/**
 * Splits a [[FirstZippedPartitionsPartition]] into its native partition and the partitions of
 * the columnar input RDDs that travel with it.
 *
 * @param split the partition handed to this RDD
 * @return the wrapped [[BaseGlutenPartition]] paired with the input RDD partitions
 * @throws SparkException if `split` is not a [[FirstZippedPartitionsPartition]] whose input
 *   partition is a [[BaseGlutenPartition]]
 */
private def castNativePartition(split: Partition): (BaseGlutenPartition, Seq[Partition]) =
  split match {
    case FirstZippedPartitionsPartition(_, nativePartition: BaseGlutenPartition, rddPartitions) =>
      (nativePartition, rddPartitions)
    case _ =>
      throw new SparkException(s"[BUG] Not a NativeSubstraitPartition: $split")
  }

override def getPreferredLocations(split: Partition): Seq[String] = {
castPartition(split).inputPartition.preferredLocations()
}

/**
 * Narrows a generic Spark partition to [[FirstZippedPartitionsPartition]].
 *
 * @throws SparkException if `split` has any other type
 */
private def castPartition(split: Partition): FirstZippedPartitionsPartition = {
  split match {
    case zipped: FirstZippedPartitionsPartition =>
      zipped
    case _ =>
      throw new SparkException(s"[BUG] Not a NativeSubstraitPartition: $split")
  }
}
castNativePartition(split)._1.preferredLocations()
}

override protected def getPartitions: Array[Partition] = {
if (rdds.isEmpty) {
inputPartitions.zipWithIndex.map {
case (inputPartition, index) => FirstZippedPartitionsPartition(index, inputPartition)
}.toArray
} else {
val numParts = inputPartitions.size
if (!rdds.forall(rdd => rdd.partitions.length == numParts)) {
throw new IllegalArgumentException(
s"Can't zip RDDs with unequal numbers of partitions: ${rdds.map(_.partitions.length)}")
}
Array.tabulate[Partition](numParts) {
i => FirstZippedPartitionsPartition(i, inputPartitions(i), rdds)
}
Array.tabulate[Partition](inputPartitions.size) {
i => FirstZippedPartitionsPartition(i, inputPartitions(i), rdds.getPartitions(i))
}
}

/**
 * Clears this RDD's dependencies through the superclass and then drops the reference to the
 * wrapped input RDDs held in `rdds`.
 */
override protected def clearDependencies(): Unit = {
super.clearDependencies()
// NOTE(review): presumably nulled so the parent RDDs can be garbage collected once the
// lineage is cut; `compute` must not be called afterwards -- confirm.
rdds = null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap}

import scala.collection.JavaConverters._
import scala.util.control.Breaks.{break, breakable}

trait ColumnarShuffledJoin extends BaseJoinExec {
def isSkewJoin: Boolean
Expand Down Expand Up @@ -420,39 +419,6 @@ abstract class BroadcastHashJoinExecTransformer(
BackendsApiManager.getBroadcastApiInstance
.collectExecutionBroadcastHashTableId(executionId, context.buildHashTableId)

val buildRDD = if (streamedRDD.isEmpty) {
// Stream plan itself contains scan and has no input rdd,
// so the number of partitions cannot be decided here.
BroadcastBuildSideRDD(sparkContext, broadcast, context)
} else {
// Try to get the number of partitions from a non-broadcast RDD.
val nonBroadcastRDD = streamedRDD.find(rdd => !rdd.isInstanceOf[BroadcastBuildSideRDD])
if (nonBroadcastRDD.isDefined) {
BroadcastBuildSideRDD(
sparkContext,
broadcast,
context,
nonBroadcastRDD.orNull.getNumPartitions)
} else {
// When all stream RDDs are broadcast RDDs, the number of partitions may be undecided
// because the stream plan may contain a scan.
var partitions = -1
breakable {
for (rdd <- streamedRDD) {
try {
partitions = rdd.getNumPartitions
break
} catch {
case _: Throwable =>
// The partitions of this RDD are not decided yet.
}
}
}
// If all the stream RDDs are broadcast RDDs,
// the number of partitions will be decided later in the whole stage transformer.
BroadcastBuildSideRDD(sparkContext, broadcast, context, partitions)
}
}
streamedRDD :+ buildRDD
streamedRDD :+ BroadcastBuildSideRDD(sparkContext, broadcast, context)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import io.glutenproject.substrait.plan.{PlanBuilder, PlanNode}
import io.glutenproject.substrait.rel.RelNode
import io.glutenproject.utils.SubstraitPlanPrinterUtil

import org.apache.spark.SparkConf
import org.apache.spark.{Dependency, OneToOneDependency, Partition, SparkConf, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
Expand Down Expand Up @@ -230,10 +230,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val pipelineTime: SQLMetric = longMetric("pipelineTime")

val buildRelationBatchHolder: mutable.ListBuffer[ColumnarBatch] = mutable.ListBuffer()

val inputRDDs = columnarInputRDDs
val inputRDDs = new ColumnarInputRDDsWrapper(columnarInputRDDs)
// Check if BatchScan exists.
val basicScanExecTransformers = findAllScanTransformers()

Expand Down Expand Up @@ -273,7 +270,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
new GlutenWholeStageColumnarRDD(
sparkContext,
substraitPlanPartitions,
genFirstNewRDDsForBroadcast(inputRDDs, partitionLength),
inputRDDs,
pipelineTime,
leafMetricsUpdater().updateInputMetrics,
BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
Expand All @@ -284,6 +281,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
)
)
} else {
val buildRelationBatchHolder: mutable.ListBuffer[ColumnarBatch] = mutable.ListBuffer()

/**
* The whole stage contains NO BasicScanExecTransformer. This is the default case for:
Expand All @@ -297,7 +295,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
}
new WholeStageZippedPartitionsRDD(
sparkContext,
genFinalNewRDDsForBroadcast(inputRDDs),
inputRDDs,
numaBindingInfo,
sparkConf,
resCtx,
Expand Down Expand Up @@ -331,42 +329,53 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
.getOrElse(NoopMetricsUpdater)
}

// Recreate the broadcast build side rdd with matched partition number.
// Used when whole stage transformer contains scan.
def genFirstNewRDDsForBroadcast(
rddSeq: Seq[RDD[ColumnarBatch]],
partitions: Int): Seq[RDD[ColumnarBatch]] = {
rddSeq.map {
case rdd: BroadcastBuildSideRDD =>
rdd.copy(numPartitions = partitions)
case inputRDD =>
inputRDD
/**
 * Returns a copy of this transformer whose child is `newChild`; `materializeInput` and
 * `transformStageId` are carried over unchanged.
 */
override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
copy(child = newChild, materializeInput = materializeInput)(transformStageId)
}

/**
* `columnarInputRDDs` may contain [[BroadcastBuildSideRDD]]s, but the dependencies and
* partitions of a [[BroadcastBuildSideRDD]] are meaningless. A [[BroadcastBuildSideRDD]] should
* only be used to hold the broadcast value and to generate the build-side iterator for the join.
*/
class ColumnarInputRDDsWrapper(columnarInputRDDs: Seq[RDD[ColumnarBatch]]) extends Serializable {
/**
 * Builds one [[OneToOneDependency]] for every input RDD that is not a
 * [[BroadcastBuildSideRDD]]; broadcast build-side RDDs contribute no dependency.
 *
 * All non-broadcast inputs must have the same number of partitions. A violation fails the
 * assertion with a message listing the partition counts that were seen.
 *
 * @return the dependencies, in the order the non-broadcast RDDs appear in the input
 */
def getDependencies: Seq[Dependency[ColumnarBatch]] = {
  val nonBroadcastRDDs = columnarInputRDDs.filterNot(_.isInstanceOf[BroadcastBuildSideRDD])
  val partitionCounts = nonBroadcastRDDs.map(_.partitions.length)
  // The message argument of `assert` is by-name, so it is only built when the check fails.
  assert(
    partitionCounts.distinct.size <= 1,
    s"Can't zip RDDs with unequal numbers of partitions: $partitionCounts")

  nonBroadcastRDDs.map(rdd => new OneToOneDependency[ColumnarBatch](rdd))
}

// Recreate the broadcast build side rdd with matched partition number.
// Used when whole stage transformer does not contain scan.
def genFinalNewRDDsForBroadcast(rddSeq: Seq[RDD[ColumnarBatch]]): Seq[RDD[ColumnarBatch]] = {
// Get the number of partitions from a non-broadcast RDD.
val nonBroadcastRDD = rddSeq.find(rdd => !rdd.isInstanceOf[BroadcastBuildSideRDD])
if (nonBroadcastRDD.isEmpty) {
throw new GlutenException("At least one RDD should not being BroadcastBuildSideRDD")
}
rddSeq.map {
case broadcastRDD: BroadcastBuildSideRDD =>
try {
broadcastRDD.getNumPartitions
broadcastRDD
} catch {
case _: Throwable =>
// Recreate the broadcast build side rdd with matched partition number.
broadcastRDD.copy(numPartitions = nonBroadcastRDD.orNull.getNumPartitions)
}
/**
 * Collects the partition at position `index` from every input RDD that is not a
 * [[BroadcastBuildSideRDD]], preserving the order of the input RDDs.
 *
 * @param index position of the partition to pick from each non-broadcast RDD
 */
def getPartitions(index: Int): Seq[Partition] = {
  for {
    rdd <- columnarInputRDDs
    if !rdd.isInstanceOf[BroadcastBuildSideRDD]
  } yield rdd.partitions(index)
}

/**
 * Returns the partition count of the first input RDD that is not a [[BroadcastBuildSideRDD]].
 *
 * Asserts that there is at least one input RDD and that at least one of them is not a
 * broadcast build-side RDD.
 */
def getPartitionLength: Int = {
  assert(columnarInputRDDs.nonEmpty)
  val partitionLength = columnarInputRDDs.collectFirst {
    case rdd if !rdd.isInstanceOf[BroadcastBuildSideRDD] => rdd.partitions.length
  }
  assert(partitionLength.isDefined)
  partitionLength.get
}

def getIterators(
inputColumnarRDDPartitions: Seq[Partition],
context: TaskContext): Seq[Iterator[ColumnarBatch]] = {
var index = 0
columnarInputRDDs.map {
case broadcast: BroadcastBuildSideRDD =>
BackendsApiManager.getIteratorApiInstance
.genBroadcastBuildSideIterator(broadcast.broadcasted, broadcast.broadCastContext)
case rdd =>
rdd
val it = rdd.iterator(inputColumnarRDDPartitions(index), context)
index += 1
it
}
}

/**
 * Returns a copy of this transformer whose child is `newChild`; `materializeInput` and
 * `transformStageId` are carried over unchanged.
 */
override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
copy(child = newChild, materializeInput = materializeInput)(transformStageId)
}
Loading

0 comments on commit 0c59a5d

Please sign in to comment.