Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Skipping slicing on shuffle arrays in shuffle reader #189

Merged
merged 2 commits into from
Mar 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle

import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometConf
import org.apache.comet.vector.{NativeUtil, StreamReader}
import org.apache.comet.vector.StreamReader

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {

private val nativeUtil = new NativeUtil

private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)

private val reader = StreamReader(channel)
private var currentIdx = -1
private var batch = nextBatch()
private var previousBatch: ColumnarBatch = null
private var currentBatch: ColumnarBatch = null

override def hasNext: Boolean = {
Expand All @@ -57,47 +49,28 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
}

val nextBatch = batch.get
val batchRows = nextBatch.numRows()
val numRows = Math.min(batchRows - currentIdx, maxBatchSize)

// Release the previous sliced batch.
// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
// Close plain arrays in the previous sliced batch.
// The dictionary arrays will be closed when closing the entire batch.
currentBatch.close()
}

currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
currentIdx += numRows

if (currentIdx == batchRows) {
// We cannot close the batch here, because if there is dictionary array in the batch,
// the dictionary array will be closed immediately, and the returned sliced batch will
// be invalid.
previousBatch = batch.get

batch = None
currentIdx = -1
}

currentBatch = nextBatch
batch = None
currentBatch
}

private def nextBatch(): Option[ColumnarBatch] = {
if (previousBatch != null) {
previousBatch.close()
previousBatch = null
}
currentIdx = 0
reader.nextBatch()
}

def close(): Unit =
synchronized {
if (currentBatch != null) {
currentBatch.close()
currentBatch = null
}
reader.close()
}
Expand Down
Loading