diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 87697309f..78477486b 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -624,15 +624,15 @@ class CometSparkSessionExtensions val newOp = transform1(w) newOp match { case Some(nativeOp) => - val cometOp = - CometWindowExec( - w, - w.output, - w.windowExpression, - w.partitionSpec, - w.orderSpec, - w.child) - CometSinkPlaceHolder(nativeOp, w, cometOp) + CometWindowExec( + nativeOp, + w, + w.output, + w.windowExpression, + w.partitionSpec, + w.orderSpec, + w.child, + SerializedPlan(None)) case None => w } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala index 4c10a8abb..bb1871705 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala @@ -19,20 +19,15 @@ package org.apache.spark.sql.comet -import scala.collection.JavaConverters.asJavaIterableConverter - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, SortOrder, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.comet.CometWindowExec.getNativePlan import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec.{METRIC_NATIVE_TIME_DESCRIPTION, METRIC_NATIVE_TIME_NAME} -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.serde.OperatorOuterClass +import com.google.common.base.Objects + import org.apache.comet.serde.OperatorOuterClass.Operator -import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, windowExprToProto} /** * Comet physical plan node for Spark `WindowsExec`. @@ -42,14 +37,15 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType, wi * executions separated by a Comet shuffle exchange. */ case class CometWindowExec( + override val nativeOp: Operator, override val originalPlan: SparkPlan, override val output: Seq[Attribute], windowExpression: Seq[NamedExpression], partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - child: SparkPlan) - extends CometExec - with UnaryExecNode { + child: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometUnaryExec { override def nodeName: String = "CometWindowExec" @@ -65,18 +61,6 @@ case class CometWindowExec( sparkContext, "number of partitions")) ++ readMetrics ++ writeMetrics - override def supportsColumnar: Boolean = true - - protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val childRDD = child.executeColumnar() - - childRDD.mapPartitionsInternal { iter => - CometExec.getCometIterator( - Seq(iter), - getNativePlan(output, windowExpression, partitionSpec, orderSpec, child).get) - } - } - override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def outputPartitioning: Partitioning = child.outputPartitioning @@ -84,52 +68,20 @@ case class CometWindowExec( protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) -} - -object CometWindowExec { - def getNativePlan( - outputAttributes: Seq[Attribute], - windowExpression: Seq[NamedExpression], - partitionSpec: Seq[Expression], - orderSpec: Seq[SortOrder], - child: SparkPlan): Option[Operator] = { - - val orderSpecs = orderSpec.map(exprToProto(_, child.output)) - val partitionSpecs = partitionSpec.map(exprToProto(_, child.output)) - val scanBuilder = OperatorOuterClass.Scan.newBuilder() - val scanOpBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) + override def stringArgs: Iterator[Any] = + Iterator(output, windowExpression, partitionSpec, orderSpec, child) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometWindowExec => + this.windowExpression == other.windowExpression && this.child == other.child && + this.partitionSpec == other.partitionSpec && this.orderSpec == other.orderSpec && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false } - - val windowExprs = windowExpression.map(w => - windowExprToProto( - w.asInstanceOf[Alias].child.asInstanceOf[WindowExpression], - outputAttributes)) - - val windowBuilder = OperatorOuterClass.Window - .newBuilder() - - if (windowExprs.forall(_.isDefined)) { - windowBuilder - .addAllWindowExpr(windowExprs.map(_.get).asJava) - - if (orderSpecs.forall(_.isDefined)) { - windowBuilder.addAllOrderByList(orderSpecs.map(_.get).asJava) - } - - if (partitionSpecs.forall(_.isDefined)) { - windowBuilder.addAllPartitionByList(partitionSpecs.map(_.get).asJava) - } - - scanBuilder.addAllFields(scanTypes.asJava) - - val opBuilder = OperatorOuterClass.Operator - .newBuilder() - .addChildren(scanOpBuilder.setScan(scanBuilder)) - - Some(opBuilder.setWindow(windowBuilder).build()) - } else None } + + override def hashCode(): Int = + Objects.hashCode(windowExpression, partitionSpec, orderSpec, child) } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 5cbc4975e..56da81cbf 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -63,6 +63,29 @@ class CometExecSuite extends CometTestBase { } } + test("Native window operator should be CometUnaryExec") { + withTempView("testData") { + sql(""" + |CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES + |(null, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + |(1, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "a"), + |(1, 2L, 2.5D, date("2017-08-02"), timestamp_seconds(1502000000), "a"), + |(2, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "a"), + |(1, null, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), "b"), + |(2, 3L, 3.3D, date("2017-08-03"), timestamp_seconds(1503000000), "b"), + |(3, 2147483650L, 100.001D, date("2020-12-31"), timestamp_seconds(1609372800), "b"), + |(null, null, null, null, null, null), + |(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null) + |AS testData(val, val_long, val_double, val_date, val_timestamp, cate) + |""".stripMargin) + val df = sql(""" + |SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) + |FROM testData ORDER BY cate, val + |""".stripMargin) + checkSparkAnswer(df) + } + } + test("subquery execution under CometTakeOrderedAndProjectExec should not fail") { assume(isSpark35Plus, "SPARK-45584 is fixed in Spark 3.5+")