Add CometColumnarToRowExec

viirya committed Dec 8, 2024
1 parent f206260 · commit cdc8f10
Showing 4 changed files with 208 additions and 4 deletions.
CometSparkSessionExtensions.scala
@@ -1083,7 +1083,10 @@ class CometSparkSessionExtensions
   case class EliminateRedundantTransitions(session: SparkSession) extends Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = {
       val eliminatedPlan = plan transformUp {
+        case ColumnarToRowExec(child) => CometColumnarToRowExec(child)
         case ColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) => sparkToColumnar.child
+        case CometColumnarToRowExec(sparkToColumnar: CometSparkToColumnarExec) =>
+          sparkToColumnar.child
         case CometSparkToColumnarExec(child: CometSparkToColumnarExec) => child
         // Spark adds `RowToColumnar` under Comet columnar shuffle. But it's redundant as the
         // shuffle takes row-based input.
@@ -1100,6 +1103,8 @@ class CometSparkSessionExtensions
       eliminatedPlan match {
         case ColumnarToRowExec(child: CometCollectLimitExec) =>
           child
+        case CometColumnarToRowExec(child: CometCollectLimitExec) =>
+          child
         case other =>
           other
       }
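To make the rewrite concrete, here is a small self-contained Scala sketch of what this rule does, using toy node types rather than Spark's (and with the more specific patterns condensed so they are tried first): a row transition sitting directly on top of a Comet row-to-columnar transition cancels out entirely, and any remaining Spark ColumnarToRowExec is swapped for the Comet copy.

sealed trait Node
case object Leaf extends Node
case class ToRow(child: Node) extends Node // stands in for ColumnarToRowExec
case class CometToRow(child: Node) extends Node // stands in for CometColumnarToRowExec
case class ToColumnar(child: Node) extends Node // stands in for CometSparkToColumnarExec

// Bottom-up rewrite in the style of Catalyst's transformUp:
// children first, then the rule is applied once to the node itself.
def transformUp(n: Node)(rule: PartialFunction[Node, Node]): Node = {
  val rewritten = n match {
    case ToRow(c) => ToRow(transformUp(c)(rule))
    case CometToRow(c) => CometToRow(transformUp(c)(rule))
    case ToColumnar(c) => ToColumnar(transformUp(c)(rule))
    case Leaf => Leaf
  }
  rule.applyOrElse(rewritten, identity[Node])
}

val eliminate: PartialFunction[Node, Node] = {
  // A row transition directly over a row-to-columnar transition is redundant: drop both.
  case ToRow(ToColumnar(child)) => child
  case CometToRow(ToColumnar(child)) => child
  // Otherwise, replace Spark's transition with the Comet copy.
  case ToRow(child) => CometToRow(child)
}

assert(transformUp(ToRow(ToColumnar(Leaf)))(eliminate) == Leaf)
assert(transformUp(ToRow(Leaf))(eliminate) == CometToRow(Leaf))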
CometColumnarToRowExec.scala (new file)
@@ -0,0 +1,198 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.spark.sql.comet

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils

/**
 * Copied from Spark's `ColumnarToRowExec`. Comet needs the fix for SPARK-50235 but cannot
 * wait for it to be released in a new Spark version, so we copy the implementation here and
 * apply the fix.
 */
case class CometColumnarToRowExec(child: SparkPlan)
    extends ColumnarToRowTransition
    with CodegenSupport {
  // supportsColumnar requires to be only called on driver side, see also SPARK-37779.
  assert(Utils.isInRunningSparkTask || child.supportsColumnar)

  override def output: Seq[Attribute] = child.output

  override def outputPartitioning: Partitioning = child.outputPartitioning

  override def outputOrdering: Seq[SortOrder] = child.outputOrdering

  // `ColumnarToRowExec` processes the input RDD directly, which is kind of a leaf node in the
  // codegen stage and needs to do the limit check.
  protected override def canCheckLimitNotReached: Boolean = true

  override lazy val metrics: Map[String, SQLMetric] = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
    "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"))

  override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val numInputBatches = longMetric("numInputBatches")
    // This avoids calling `output` in the RDD closure, so that we don't need to include the
    // entire plan (this) in the closure.
    val localOutput = this.output
    child.executeColumnar().mapPartitionsInternal { batches =>
      val toUnsafe = UnsafeProjection.create(localOutput, localOutput)
      batches.flatMap { batch =>
        numInputBatches += 1
        numOutputRows += batch.numRows()
        batch.rowIterator().asScala.map(toUnsafe)
      }
    }
  }

  /**
   * Generate [[ColumnVector]] expressions for our parent to consume as rows. This is called
   * once per [[ColumnVector]] in the batch.
   */
  private def genCodeColumnVector(
      ctx: CodegenContext,
      columnVar: String,
      ordinal: String,
      dataType: DataType,
      nullable: Boolean): ExprCode = {
    val javaType = CodeGenerator.javaType(dataType)
    val value = CodeGenerator.getValueFromVector(columnVar, dataType, ordinal)
    val isNullVar = if (nullable) {
      JavaCode.isNullVariable(ctx.freshName("isNull"))
    } else {
      FalseLiteral
    }
    val valueVar = ctx.freshName("value")
    val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
      code"""
        boolean $isNullVar = $columnVar.isNullAt($ordinal);
        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
      """
    } else {
      code"$javaType $valueVar = $value;"
    })
    ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
  }

  /**
   * Produce code to process the input iterator as [[ColumnarBatch]]es. This produces an
   * [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] for each row in each batch.
   */
  override protected def doProduce(ctx: CodegenContext): String = {
    // PhysicalRDD always just has one input
    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")

    // metrics
    val numOutputRows = metricTerm(ctx, "numOutputRows")
    val numInputBatches = metricTerm(ctx, "numInputBatches")

    val columnarBatchClz = classOf[ColumnarBatch].getName
    val batch = ctx.addMutableState(columnarBatchClz, "batch")

    val idx = ctx.addMutableState(CodeGenerator.JAVA_INT, "batchIdx") // init as batchIdx = 0
    val columnVectorClzs =
      child.vectorTypes.getOrElse(Seq.fill(output.indices.size)(classOf[ColumnVector].getName))
    val (colVars, columnAssigns) = columnVectorClzs.zipWithIndex.map {
      case (columnVectorClz, i) =>
        val name = ctx.addMutableState(columnVectorClz, s"colInstance$i")
        (name, s"$name = ($columnVectorClz) $batch.column($i);")
    }.unzip

    val nextBatch = ctx.freshName("nextBatch")
    val nextBatchFuncName = ctx.addNewFunction(
      nextBatch,
      s"""
         |private void $nextBatch() throws java.io.IOException {
         |  if ($input.hasNext()) {
         |    $batch = ($columnarBatchClz)$input.next();
         |    $numInputBatches.add(1);
         |    $numOutputRows.add($batch.numRows());
         |    $idx = 0;
         |    ${columnAssigns.mkString("", "\n", "\n")}
         |  }
         |}""".stripMargin)

    ctx.currentVars = null
    val rowidx = ctx.freshName("rowIdx")
    val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable)
    }
    val localIdx = ctx.freshName("localIdx")
    val localEnd = ctx.freshName("localEnd")
    val numRows = ctx.freshName("numRows")
    val shouldStop = if (parent.needStopCheck) {
      s"if (shouldStop()) { $idx = $rowidx + 1; return; }"
    } else {
      "// shouldStop check is eliminated"
    }

    val writableColumnVectorClz = classOf[WritableColumnVector].getName

s"""
|if ($batch == null) {
| $nextBatchFuncName();
|}
|while ($limitNotReachedCond $batch != null) {
| int $numRows = $batch.numRows();
| int $localEnd = $numRows - $idx;
| for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
| int $rowidx = $idx + $localIdx;
| ${consume(ctx, columnsBatchInput).trim}
| $shouldStop
| }
| $idx = $numRows;
|
| // Comet fix for SPARK-50235
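       |  // Close every column vector that is not a WritableColumnVector once all of the
       |  // batch's rows have been emitted, so its buffers are freed eagerly rather than
       |  // lingering until the task finishes.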
       |  for (int i = 0; i < ${colVars.length}; i++) {
       |    if (!($batch.column(i) instanceof $writableColumnVectorClz)) {
       |      $batch.column(i).close();
       |    }
       |  }
       |
       |  $batch = null;
       |  $nextBatchFuncName();
       |}
       |// Comet fix for SPARK-50235: clean up resources
       |if ($batch != null) {
       |  $batch.close();
       |}
     """.stripMargin
  }

  override def inputRDDs(): Seq[RDD[InternalRow]] = {
    Seq(child.executeColumnar().asInstanceOf[RDD[InternalRow]]) // Hack because of type erasure
  }

  override protected def withNewChildInternal(newChild: SparkPlan): CometColumnarToRowExec =
    copy(child = newChild)
}
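The essence of the copied fix: once a batch has been fully consumed, every column vector that is not a WritableColumnVector is closed immediately (writable vectors are left alone because the producing operator may still reuse them), and any batch left over when the loop exits is closed as well. A rough plain-Scala rendering of the generated per-batch cleanup loop, with a helper name of our own invention:

import org.apache.spark.sql.execution.vectorized.WritableColumnVector
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical helper mirroring the generated cleanup loop above.
def closeConsumedBatch(batch: ColumnarBatch): Unit = {
  var i = 0
  while (i < batch.numCols()) {
    batch.column(i) match {
      // Writable vectors may be reused for the next batch; leave them open.
      case _: WritableColumnVector => ()
      // Anything else can be closed as soon as its rows have been emitted.
      case other => other.close()
    }
    i += 1
  }
}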
CometExpressionSuite.scala
@@ -28,8 +28,8 @@ import scala.util.Random
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
-import org.apache.spark.sql.comet.CometProjectExec
-import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, ProjectExec, WholeStageCodegenExec}
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
+import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -749,7 +749,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
       val project = cometPlan
         .asInstanceOf[WholeStageCodegenExec]
         .child
-        .asInstanceOf[ColumnarToRowExec]
+        .asInstanceOf[CometColumnarToRowExec]
         .child
         .asInstanceOf[InputAdapter]
         .child
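Because the rule above replaces Spark's ColumnarToRowExec with the Comet node, tests that walk the physical plan now match on CometColumnarToRowExec instead, as in the hunk above. A hedged sketch of such an assertion (df is an illustrative DataFrame, not part of the commit):

import org.apache.spark.sql.comet.CometColumnarToRowExec

val plan = df.queryExecution.executedPlan
assert(
  plan.collectFirst { case c: CometColumnarToRowExec => c }.isDefined,
  "expected ColumnarToRowExec to be replaced by CometColumnarToRowExec")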
CometTestBase.scala
@@ -36,7 +36,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
+import org.apache.spark.sql.comet.{CometBatchScanExec, CometBroadcastExchangeExec, CometColumnarToRowExec, CometExec, CometScanExec, CometScanWrapper, CometSinkPlaceHolder, CometSparkToColumnarExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExtendedMode, InputAdapter, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -174,6 +174,7 @@ abstract class CometTestBase
     wrapped.foreach {
       case _: CometScanExec | _: CometBatchScanExec =>
       case _: CometSinkPlaceHolder | _: CometScanWrapper =>
+      case _: CometColumnarToRowExec =>
       case _: CometSparkToColumnarExec =>
       case _: CometExec | _: CometShuffleExchangeExec =>
       case _: CometBroadcastExchangeExec =>