From bbd6b3f49974c386a499712ea4583f447e8bcda1 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 8 Dec 2024 15:12:07 -0800 Subject: [PATCH] fix --- .../apache/spark/sql/comet/CometColumnarToRowExec.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala index b50f0cd07..1322129e0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector} import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils @@ -157,7 +157,9 @@ case class CometColumnarToRowExec(child: SparkPlan) } val writableColumnVectorClz = classOf[WritableColumnVector].getName + val constantColumnVectorClz = classOf[ConstantColumnVector].getName + // scalastyle:off line.size.limit s""" |if ($batch == null) { | $nextBatchFuncName(); @@ -174,7 +176,7 @@ case class CometColumnarToRowExec(child: SparkPlan) | | // Comet fix for SPARK-50235 | for (int i = 0; i < ${colVars.length}; i++) { - | if (!($batch.column(i) instanceof $writableColumnVectorClz)) { + | if (!($batch.column(i) instanceof $writableColumnVectorClz || $batch.column(i) instanceof $constantColumnVectorClz)) { | $batch.column(i).close(); | } | } @@ -187,6 +189,7 @@ case class CometColumnarToRowExec(child: SparkPlan) | $batch.close(); |} """.stripMargin + // scalastyle:on line.size.limit } override def inputRDDs(): Seq[RDD[InternalRow]] = {