diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 1153b559a..de49fdfb0 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -227,7 +227,9 @@ object CometConf { val COMET_COLUMNAR_SHUFFLE_BATCH_SIZE: ConfigEntry[Int] = conf("spark.comet.columnar.shuffle.batch.size") .internal() - .doc("Batch size when writing out sorted spill files on the native side.") + .doc("Batch size when writing out sorted spill files on the native side. Note that " + + "this should not be larger than batch size (i.e., `spark.comet.batchSize`). Otherwise " + + "it will produce larger batches than expected in the native operator after shuffle.") .intConf .createWithDefault(8192) diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index c17c5bce9..e8dba93e7 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle import java.nio.channels.ReadableByteChannel -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometConf -import org.apache.comet.vector.{NativeUtil, StreamReader} +import org.apache.comet.vector.StreamReader class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { - private val nativeUtil = new NativeUtil - - private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get) - private val reader = StreamReader(channel) - private var currentIdx = -1 private var batch = nextBatch() - private var previousBatch: ColumnarBatch = null private var currentBatch: ColumnarBatch = null override def hasNext: Boolean = { @@ -57,40 +49,20 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna } val nextBatch = batch.get - val batchRows = nextBatch.numRows() - val numRows = Math.min(batchRows - currentIdx, maxBatchSize) - // Release the previous sliced batch. + // Release the previous batch. // If it is not released, when closing the reader, arrow library will complain about // memory leak. if (currentBatch != null) { - // Close plain arrays in the previous sliced batch. - // The dictionary arrays will be closed when closing the entire batch. currentBatch.close() } - currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows) - currentIdx += numRows - - if (currentIdx == batchRows) { - // We cannot close the batch here, because if there is dictionary array in the batch, - // the dictionary array will be closed immediately, and the returned sliced batch will - // be invalid. - previousBatch = batch.get - - batch = None - currentIdx = -1 - } - + currentBatch = nextBatch + batch = None currentBatch } private def nextBatch(): Option[ColumnarBatch] = { - if (previousBatch != null) { - previousBatch.close() - previousBatch = null - } - currentIdx = 0 reader.nextBatch() } @@ -98,6 +70,7 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna synchronized { if (currentBatch != null) { currentBatch.close() + currentBatch = null } reader.close() }