[SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records

### What changes were proposed in this pull request?

This PR proposes to add a new method in `DataWriter` that supports writing an iterator of records:
```java
void writeAll(Iterator<T> records) throws IOException
```
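
For illustration, a data source implementation can override this default to consume the iterator directly, for example to buffer rows and flush them in batches. The sketch below is hypothetical; the `BufferedRowWriter` name, the batch size, and the `flush()` target are illustrative and not part of this change:
```java
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

// Hypothetical connector-side writer that batches rows before sending them to the sink.
public class BufferedRowWriter implements DataWriter<InternalRow> {
  private static final int BATCH_SIZE = 1000;
  private final List<InternalRow> buffer = new ArrayList<>();

  @Override
  public void write(InternalRow record) throws IOException {
    // Rows may be reused by the engine, so copy before buffering.
    buffer.add(record.copy());
    if (buffer.size() >= BATCH_SIZE) {
      flush();
    }
  }

  @Override
  public void writeAll(Iterator<InternalRow> records) throws IOException {
    // With the whole iterator available, the writer can drain it in one pass
    // instead of relying on the engine to call write() once per record.
    while (records.hasNext()) {
      write(records.next());
    }
    flush();
  }

  @Override
  public WriterCommitMessage commit() throws IOException {
    flush();
    return new WriterCommitMessage() {};  // a real writer would report what it wrote
  }

  @Override
  public void abort() throws IOException {
    buffer.clear();
  }

  @Override
  public void close() throws IOException {
    buffer.clear();
  }

  private void flush() throws IOException {
    // Send the buffered rows to the underlying sink (omitted in this sketch).
    buffer.clear();
  }
}
```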

### Why are the changes needed?

To make the API more flexible and to support more use cases (e.g., Python data sources). See #43791

### Does this PR introduce _any_ user-facing change?

Yes. This PR introduces a new method in `DataWriter`.

### How was this patch tested?

Existing unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44410 from allisonwang-db/spark-46452-dsv2-write-all.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
allisonwang-db authored and cloud-fan committed Dec 20, 2023
1 parent cf54e8f commit c9cfaac
Showing 2 changed files with 88 additions and 51 deletions.
@@ -19,6 +19,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.util.Iterator;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.metric.CustomTaskMetric;
@@ -74,6 +75,23 @@ public interface DataWriter<T> extends Closeable {
*/
void write(T record) throws IOException;

/**
* Writes all records provided by the given iterator. By default, it calls the {@link #write}
* method for each record in the iterator.
* <p>
* If this method fails (by throwing an exception), {@link #abort()} will be called and this
* data writer is considered to have failed.
*
* @throws IOException if failure happens during disk/network IO like writing files.
*
* @since 4.0.0
*/
default void writeAll(Iterator<T> records) throws IOException {
while (records.hasNext()) {
write(records.next());
}
}

/**
* Commits this writer after all records are written successfully, returns a commit message which
* will be sent back to driver side and passed to
@@ -421,7 +421,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {

trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {

protected def write(writer: W, row: InternalRow): Unit
protected def write(writer: W, iter: java.util.Iterator[InternalRow]): Unit

def run(
writerFactory: DataWriterFactory,
@@ -436,19 +436,11 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
val attemptId = context.attemptNumber()
val dataWriter = writerFactory.createWriter(partId, taskId).asInstanceOf[W]

var count = 0L
val iterWithMetrics = IteratorWithMetrics(iter, dataWriter, customMetrics)

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
while (iter.hasNext) {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(
dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
}

// Count is here.
count += 1
write(dataWriter, iter.next())
}
write(dataWriter, iterWithMetrics)

CustomMetrics.updateMetrics(
dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
@@ -476,7 +468,7 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId, " +
s"stage $stageId.$stageAttempt)")

DataWritingSparkTaskResult(count, msg)
DataWritingSparkTaskResult(iterWithMetrics.count, msg)

})(catchBlock = {
// If there is an error, abort this writer
@@ -489,11 +481,30 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
dataWriter.close()
})
}

private case class IteratorWithMetrics(
iter: Iterator[InternalRow],
dataWriter: W,
customMetrics: Map[String, SQLMetric]) extends java.util.Iterator[InternalRow] {
var count = 0L

override def hasNext: Boolean = iter.hasNext

override def next(): InternalRow = {
if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
CustomMetrics.updateMetrics(
dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
}
count += 1
iter.next()
}
}
}

object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] {
override protected def write(writer: DataWriter[InternalRow], row: InternalRow): Unit = {
writer.write(row)
override protected def write(
writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
writer.writeAll(iter)
}
}

@@ -503,25 +514,29 @@ case class DeltaWritingSparkTask(
private lazy val rowProjection = projections.rowProjection.orNull
private lazy val rowIdProjection = projections.rowIdProjection

override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
val operation = row.getInt(0)
override protected def write(
writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
while (iter.hasNext) {
val row = iter.next()
val operation = row.getInt(0)

operation match {
case DELETE_OPERATION =>
rowIdProjection.project(row)
writer.delete(null, rowIdProjection)
operation match {
case DELETE_OPERATION =>
rowIdProjection.project(row)
writer.delete(null, rowIdProjection)

case UPDATE_OPERATION =>
rowProjection.project(row)
rowIdProjection.project(row)
writer.update(null, rowIdProjection, rowProjection)
case UPDATE_OPERATION =>
rowProjection.project(row)
rowIdProjection.project(row)
writer.update(null, rowIdProjection, rowProjection)

case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)
case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)

case other =>
throw new SparkException(s"Unexpected operation ID: $other")
case other =>
throw new SparkException(s"Unexpected operation ID: $other")
}
}
}
}
@@ -533,27 +548,31 @@ case class DeltaWithMetadataWritingSparkTask(
private lazy val rowIdProjection = projections.rowIdProjection
private lazy val metadataProjection = projections.metadataProjection.orNull

override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
val operation = row.getInt(0)

operation match {
case DELETE_OPERATION =>
rowIdProjection.project(row)
metadataProjection.project(row)
writer.delete(metadataProjection, rowIdProjection)

case UPDATE_OPERATION =>
rowProjection.project(row)
rowIdProjection.project(row)
metadataProjection.project(row)
writer.update(metadataProjection, rowIdProjection, rowProjection)

case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)

case other =>
throw new SparkException(s"Unexpected operation ID: $other")
override protected def write(
writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
while (iter.hasNext) {
val row = iter.next()
val operation = row.getInt(0)

operation match {
case DELETE_OPERATION =>
rowIdProjection.project(row)
metadataProjection.project(row)
writer.delete(metadataProjection, rowIdProjection)

case UPDATE_OPERATION =>
rowProjection.project(row)
rowIdProjection.project(row)
metadataProjection.project(row)
writer.update(metadataProjection, rowIdProjection, rowProjection)

case INSERT_OPERATION =>
rowProjection.project(row)
writer.insert(rowProjection)

case other =>
throw new SparkException(s"Unexpected operation ID: $other")
}
}
}
}
