Skip to content

Commit

Permalink
fix: Native window operator should be CometUnaryExec
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Aug 4, 2024
1 parent 2d95fea commit 60c7cea
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -621,15 +621,15 @@ class CometSparkSessionExtensions
val newOp = transform1(w)
newOp match {
case Some(nativeOp) =>
val cometOp =
CometWindowExec(
w,
w.output,
w.windowExpression,
w.partitionSpec,
w.orderSpec,
w.child)
CometSinkPlaceHolder(nativeOp, w, cometOp)
CometWindowExec(
nativeOp,
w,
w.output,
w.windowExpression,
w.partitionSpec,
w.orderSpec,
w.child,
SerializedPlan(None))
case None =>
w
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,15 @@

package org.apache.spark.sql.comet

import scala.collection.JavaConverters.asJavaIterableConverter

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, SortOrder, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.CometWindowExec.getNativePlan
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec.{METRIC_NATIVE_TIME_DESCRIPTION, METRIC_NATIVE_TIME_NAME}
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.serde.OperatorOuterClass
import com.google.common.base.Objects

import org.apache.comet.serde.OperatorOuterClass.Operator
import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, windowExprToProto}

/**
* Comet physical plan node for Spark `WindowsExec`.
Expand All @@ -42,14 +37,15 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, wi
* executions separated by a Comet shuffle exchange.
*/
case class CometWindowExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
override val output: Seq[Attribute],
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
extends CometExec
with UnaryExecNode {
child: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometUnaryExec {

override def nodeName: String = "CometWindowExec"

Expand All @@ -65,71 +61,27 @@ case class CometWindowExec(
sparkContext,
"number of partitions")) ++ readMetrics ++ writeMetrics

override def supportsColumnar: Boolean = true

protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val childRDD = child.executeColumnar()

childRDD.mapPartitionsInternal { iter =>
CometExec.getCometIterator(
Seq(iter),
getNativePlan(output, windowExpression, partitionSpec, orderSpec, child).get)
}
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning

protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
this.copy(child = newChild)

}

object CometWindowExec {
def getNativePlan(
outputAttributes: Seq[Attribute],
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan): Option[Operator] = {

val orderSpecs = orderSpec.map(exprToProto(_, child.output))
val partitionSpecs = partitionSpec.map(exprToProto(_, child.output))
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
val scanOpBuilder = OperatorOuterClass.Operator.newBuilder()

val scanTypes = outputAttributes.flatten { attr =>
serializeDataType(attr.dataType)
override def stringArgs: Iterator[Any] =
Iterator(output, windowExpression, partitionSpec, orderSpec, child)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometWindowExec =>
this.windowExpression == other.windowExpression && this.child == other.child &&
this.partitionSpec == other.partitionSpec && this.orderSpec == other.orderSpec &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}

val windowExprs = windowExpression.map(w =>
windowExprToProto(
w.asInstanceOf[Alias].child.asInstanceOf[WindowExpression],
outputAttributes))

val windowBuilder = OperatorOuterClass.Window
.newBuilder()

if (windowExprs.forall(_.isDefined)) {
windowBuilder
.addAllWindowExpr(windowExprs.map(_.get).asJava)

if (orderSpecs.forall(_.isDefined)) {
windowBuilder.addAllOrderByList(orderSpecs.map(_.get).asJava)
}

if (partitionSpecs.forall(_.isDefined)) {
windowBuilder.addAllPartitionByList(partitionSpecs.map(_.get).asJava)
}

scanBuilder.addAllFields(scanTypes.asJava)

val opBuilder = OperatorOuterClass.Operator
.newBuilder()
.addChildren(scanOpBuilder.setScan(scanBuilder))

Some(opBuilder.setWindow(windowBuilder).build())
} else None
}

override def hashCode(): Int =
Objects.hashCode(windowExpression, partitionSpec, orderSpec, child)
}
23 changes: 23 additions & 0 deletions spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,29 @@ class CometExecSuite extends CometTestBase {
}
}

test("Native window operator should be CometUnaryExec") {
withTempView("testData") {
sql("""
|CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
|(null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"),
|(1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"),
|(1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"),
|(2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"),
|(1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"),
|(2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"),
|(3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"),
|(null, null, null, null, null, null),
|(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
|AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
|""".stripMargin)
val df = sql("""
|SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
|FROM testData ORDER BY cate, val
|""".stripMargin)
checkSparkAnswer(df)
}
}

test("subquery execution under CometTakeOrderedAndProjectExec should not fail") {
assume(isSpark35Plus, "SPARK-45584 is fixed in Spark 3.5+")

Expand Down

0 comments on commit 60c7cea

Please sign in to comment.