fix: Use RDD partition index #1112

Merged · 4 commits · Nov 25, 2024
12 changes: 10 additions & 2 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
  *   The input iterators producing sequence of batches of Arrow Arrays.
  * @param protobufQueryPlan
  *   The serialized bytes of Spark execution plan.
+ * @param numParts
+ *   The number of partitions.
+ * @param partitionIndex
+ *   The index of the partition.
  */
 class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     numOutputCols: Int,
     protobufQueryPlan: Array[Byte],
-    nativeMetrics: CometMetricNode)
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int)
     extends Iterator[ColumnarBatch] {
 
   private val nativeLib = new Native()
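The two new constructor arguments carry the RDD's partition count and the index of the partition whose batches this iterator consumes. A minimal sketch of how a caller supplies them, assuming placeholder names `rdd`, `outputAttrs`, `planBytes` and `metrics` (not code from this PR):

```scala
// Capture the partition count on the driver, then create one iterator per partition.
val numParts = rdd.getNumPartitions
rdd.mapPartitionsWithIndex { (idx, iter) =>
  new CometExecIterator(
    CometExec.newIterId, // id
    Seq(iter),           // inputs
    outputAttrs.length,  // numOutputCols
    planBytes,           // protobufQueryPlan
    metrics,             // nativeMetrics
    numParts,            // numParts
    idx)                 // partitionIndex
}
```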
@@ -92,11 +98,13 @@ class CometExecIterator(
   }
 
   def getNextBatch(): Option[ColumnarBatch] = {
+    assert(partitionIndex >= 0 && partitionIndex < numParts)
+
     nativeUtil.getNextBatch(
       numOutputCols,
       (arrayAddrs, schemaAddrs) => {
         val ctx = TaskContext.get()
-        nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs)
+        nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
       })
   }
 
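This is the substantive fix: the partition index passed to the native plan is now the one fixed when the iterator was created for a given RDD partition, not whatever `TaskContext.partitionId()` happens to return for the running task, and the new assertion spells out the caller's contract (`0 <= partitionIndex < numParts`). A sketch of what an out-of-range caller would hit, with made-up values and an assumed `planBytes` array:

```scala
val badIter = new CometExecIterator(
  id = 1L,
  inputs = Seq(Iterator.empty),
  numOutputCols = 1,
  protobufQueryPlan = planBytes,              // assumed: some serialized native plan
  nativeMetrics = CometMetricNode(Map.empty),
  numParts = 2,
  partitionIndex = 5)                         // violates the contract: 5 >= 2
badIter.hasNext                               // trips the assertion once a batch is requested
```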
@@ -51,9 +51,10 @@ object CometExecUtils
       childPlan: RDD[ColumnarBatch],
       outputAttribute: Seq[Attribute],
       limit: Int): RDD[ColumnarBatch] = {
-    childPlan.mapPartitionsInternal { iter =>
+    val numParts = childPlan.getNumPartitions
+    childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
       val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
-      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
+      CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
     }
   }
 
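For reference, the same pattern written with the public RDD API; `mapPartitionsWithIndexInternal` is the `private[spark]` variant that skips closure cleaning but otherwise behaves like `mapPartitionsWithIndex`:

```scala
// The partition count must be resolved on the driver, before the closure ships to executors.
val numParts = childPlan.getNumPartitions
childPlan.mapPartitionsWithIndex { (idx, iter) =>
  val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
  CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
}
```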
@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
     val localTopK = if (orderingSatisfies) {
       CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
     } else {
-      childRDD.mapPartitionsInternal { iter =>
+      val numParts = childRDD.getNumPartitions
+      childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
         val topK =
           CometExecUtils
             .getTopKNativePlan(child.output, sortOrder, child, limit)
             .get
-        CometExec.getCometIterator(Seq(iter), child.output.length, topK)
+        CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
       }
     }
 
@@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec(
       val topKAndProjection = CometExecUtils
         .getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
         .get
-      val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
+      val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
       setSubqueries(it.id, this)
 
       Option(TaskContext.get()).foreach { context =>
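The literal `1, 0` encodes an assumption rather than a shortcut: this final top-K-plus-projection step is expected to consume a single-partition RDD, so the partition count is 1 and the only valid index is 0. Under that assumption it is the degenerate case of the general pattern (`singlePartitionRdd` is a placeholder name):

```scala
val numParts = singlePartitionRdd.getNumPartitions // == 1 by assumption
singlePartitionRdd.mapPartitionsWithIndex { (idx, iter) => // idx can only be 0
  CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, numParts, idx)
}
```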
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
  */
 private[spark] class ZippedPartitionsRDD(
     sc: SparkContext,
-    var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
+    var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
     var zipRdds: Seq[RDD[ColumnarBatch]],
     preservesPartitioning: Boolean = false)
     extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {
 
+  // We need to get the number of partitions in `compute` but `getNumPartitions` is not available
+  // on the executors. So we need to capture it here.
+  private val numParts: Int = this.getNumPartitions
+
   override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val iterators =
       zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
-    f(iterators)
+    f(iterators, numParts, s.index)
   }
 
   override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD(
 
 object ZippedPartitionsRDD {
   def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
-      f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
+      f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
+      : RDD[ColumnarBatch] =
     withScope(sc) {
       new ZippedPartitionsRDD(sc, f, rdds)
     }
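With the widened function type, the partition count captured on the driver and the index of the partition being computed reach the caller-supplied function alongside the input iterators. A usage sketch; `buildIterator` and the two input RDDs are placeholders:

```scala
ZippedPartitionsRDD(sc, Seq(leftRdd, rightRdd)) {
  (iterators: Seq[Iterator[ColumnarBatch]], numParts: Int, partitionIndex: Int) =>
    // Both values describe the zipped RDD itself, not the task that happens to run it.
    buildIterator(iterators, numParts, partitionIndex)
}
```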
@@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
       outputPartitioning: Partitioning,
       serializer: Serializer,
       metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
+    val numParts = rdd.getNumPartitions
     val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
       rdd.map(
         (0, _)
       ), // adding fake partitionId that is always 0 because ShuffleDependency requires it
       serializer = serializer,
       shuffleWriterProcessor =
-        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
+        new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
       shuffleType = CometNativeShuffle,
       partitioner = new Partitioner {
         override def numPartitions: Int = outputPartitioning.numPartitions
@@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
 class CometShuffleWriteProcessor(
     outputPartitioning: Partitioning,
     outputAttributes: Seq[Attribute],
-    metrics: Map[String, SQLMetric])
+    metrics: Map[String, SQLMetric],
+    numParts: Int)
     extends ShimCometShuffleWriteProcessor {
 
   private val OFFSET_LENGTH = 8
@@ -499,7 +501,9 @@ class CometShuffleWriteProcessor(
         Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
         outputAttributes.length,
         nativePlan,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        context.partitionId())
 
       while (cometIter.hasNext) {
         cometIter.next()
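On the shuffle write path the two values come from different places: the partition count is read once on the driver when the shuffle dependency is built (the `rdd.getNumPartitions` added above) and travels inside the processor, while the partition index is taken from the write task's `TaskContext`. A rough sketch of that wiring, using the names from the diff:

```scala
// Driver side: capture the map-side partition count and hand it to the processor.
val numParts = rdd.getNumPartitions
val writeProcessor =
  new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts)
// Executor side (inside the write task): the index is the task's own partition id,
// e.g. getCometIterator(..., numParts, context.partitionId())
```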
36 changes: 29 additions & 7 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
-      nativePlan: Operator): CometExecIterator = {
-    getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
+      nativePlan: Operator,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
+    getCometIterator(
+      inputs,
+      numOutputCols,
+      nativePlan,
+      CometMetricNode(Map.empty),
+      numParts,
+      partitionIdx)
   }
 
   def getCometIterator(
       inputs: Seq[Iterator[ColumnarBatch]],
       numOutputCols: Int,
       nativePlan: Operator,
-      nativeMetrics: CometMetricNode): CometExecIterator = {
+      nativeMetrics: CometMetricNode,
+      numParts: Int,
+      partitionIdx: Int): CometExecIterator = {
     val outputStream = new ByteArrayOutputStream()
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-    new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
+    new CometExecIterator(
+      newIterId,
+      inputs,
+      numOutputCols,
+      bytes,
+      nativeMetrics,
+      numParts,
+      partitionIdx)
   }
 
   /**
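Both overloads now take the partition count and index, and neither parameter has a default value, so a stale call site fails to compile instead of silently running with a wrong partition index. Illustrative call with placeholder names `numCols`, `plan`, `numParts` and `idx`:

```scala
// CometExec.getCometIterator(Seq(iter), numCols, plan)             // no longer compiles
CometExec.getCometIterator(Seq(iter), numCols, plan, numParts, idx) // required form
```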
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
     // TODO: support native metrics for all operators.
     val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
-    def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
+    def createCometExecIter(
+        inputs: Seq[Iterator[ColumnarBatch]],
+        numParts: Int,
+        partitionIndex: Int): CometExecIterator = {
       val it = new CometExecIterator(
         CometExec.newIterId,
         inputs,
         output.length,
         serializedPlanCopy,
-        nativeMetrics)
+        nativeMetrics,
+        numParts,
+        partitionIndex)
 
       setSubqueries(it.id, this)
 
@@ -295,7 +317,7 @@ abstract class CometNativeExec extends CometExec {
         throw new CometRuntimeException(s"No input for CometNativeExec:\n $this")
       }
 
-      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
+      ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
     }
   }
 
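The call-site change from `createCometExecIter(_)` to `createCometExecIter` is forced by the widened function type: `f(_)` builds a one-argument lambda, while passing the bare method name lets the compiler eta-expand it to all three parameters. A self-contained illustration with plain `Int`s instead of Comet types:

```scala
def example(inputs: Seq[Int], numParts: Int, partitionIndex: Int): Int = inputs.sum
val g: (Seq[Int], Int, Int) => Int = example       // eta-expansion: compiles
// val h: (Seq[Int], Int, Int) => Int = example(_) // does not compile:
//   `example(_)` means `x => example(x)`, a one-argument lambda
```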
4 changes: 3 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
           override def next(): ColumnarBatch = throw new NullPointerException()
         }),
         1,
-        limitOp)
+        limitOp,
+        1,
+        0)
       cometIter.next()
       cometIter.close()
       value