From 117901b928f423eff9cce7324b27c02d2bf5434f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jul 2024 16:58:08 -0400 Subject: [PATCH] Extract CoalesceBatchesStream to a struct --- .../physical-plan/src/coalesce_batches.rs | 255 ++++++++++-------- 1 file changed, 142 insertions(+), 113 deletions(-) diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index b9bdfcdee712c..7806556e70c17 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -21,7 +21,7 @@ use std::any::Any; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics}; @@ -146,10 +146,7 @@ impl ExecutionPlan for CoalesceBatchesExec { ) -> Result { Ok(Box::pin(CoalesceBatchesStream { input: self.input.execute(partition, context)?, - schema: self.input.schema(), - target_batch_size: self.target_batch_size, - buffer: Vec::new(), - buffered_rows: 0, + coalescer: BatchCoalescer::new(self.input.schema(), self.target_batch_size), is_closed: false, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -167,14 +164,8 @@ impl ExecutionPlan for CoalesceBatchesExec { struct CoalesceBatchesStream { /// The input plan input: SendableRecordBatchStream, - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, - /// Buffered batches - buffer: Vec, - /// Buffered row count - buffered_rows: usize, + /// Buffer for combining batches + coalescer: BatchCoalescer, /// Whether the stream has finished returning all of its data or not is_closed: bool, /// Execution metrics @@ -213,58 +204,27 @@ impl CoalesceBatchesStream { let input_batch = self.input.poll_next_unpin(cx); // records time on drop let _timer = cloned_time.timer(); - match input_batch { - Poll::Ready(x) => match x { - Some(Ok(batch)) => { - if batch.num_rows() >= self.target_batch_size - && self.buffer.is_empty() - { - return Poll::Ready(Some(Ok(batch))); - } else if batch.num_rows() == 0 { - // discard empty batches - } else { - // add to the buffered batches - self.buffered_rows += batch.num_rows(); - self.buffer.push(batch); - // check to see if we have enough batches yet - if self.buffered_rows >= self.target_batch_size { - // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return Poll::Ready(Some(Ok(batch))); - } - } - } - None => { - self.is_closed = true; - // we have reached the end of the input stream but there could still - // be buffered batches - if self.buffer.is_empty() { - return Poll::Ready(None); - } else { - // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return Poll::Ready(Some(Ok(batch))); - } + match ready!(input_batch) { + Some(result) => { + let Ok(input_batch) = result else { + return Poll::Ready(Some(result)); // pass back error + }; + // Buffer the batch and either get more input if not enough + // rows yet or output + match self.coalescer.push_batch(input_batch) { + Ok(None) => continue, + res => return Poll::Ready(res.transpose()), } - other => return Poll::Ready(other), - }, - Poll::Pending => return Poll::Pending, + } + None => { + self.is_closed = true; + // we have reached the end of the input stream but there could still + // be buffered batches + return match self.coalescer.finish() { + Ok(None) => Poll::Ready(None), + res => Poll::Ready(res.transpose()), + }; + } } } } @@ -272,7 +232,7 @@ impl CoalesceBatchesStream { impl RecordBatchStream for CoalesceBatchesStream { fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) + self.coalescer.schema() } } @@ -290,26 +250,106 @@ pub fn concat_batches( arrow::compute::concat_batches(schema, batches) } +/// Concatenating multiple record batches into larger batches +/// +/// TODO ASCII ART +/// +/// Notes: +/// +/// 1. The output is exactly the same order as the input rows +/// +/// 2. The output is a sequence of batches, with all but the last being at least +/// `target_batch_size` rows. +/// +/// 3. Eventually this may also be able to handle other optimizations such as a +/// combined filter/coalesce operation. +#[derive(Debug)] +struct BatchCoalescer { + /// The input schema + schema: SchemaRef, + /// Minimum number of rows for coalesces batches + target_batch_size: usize, + /// Buffered batches + buffer: Vec, + /// Buffered row count + buffered_rows: usize, +} + +impl BatchCoalescer { + /// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows + fn new(schema: SchemaRef, target_batch_size: usize) -> Self { + Self { + schema, + target_batch_size, + buffer: vec![], + buffered_rows: 0, + } + } + + /// Return the schema of the output batches + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + /// Add a batch to the coalescer, returning a batch if the target batch size is reached + fn push_batch(&mut self, batch: RecordBatch) -> Result> { + if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() { + return Ok(Some(batch)); + } + // discard empty batches + if batch.num_rows() == 0 { + return Ok(None); + } + // add to the buffered batches + self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); + // check to see if we have enough batches yet + let batch = if self.buffered_rows >= self.target_batch_size { + // combine the batches and return + let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?; + // reset buffer state + self.buffer.clear(); + self.buffered_rows = 0; + // return batch + Some(batch) + } else { + None + }; + Ok(batch) + } + + /// Finish the coalescing process, returning all buffered data as a final, + /// single batch, if any + fn finish(&mut self) -> Result> { + if self.buffer.is_empty() { + Ok(None) + } else { + // combine the batches and return + let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?; + // reset buffer state + self.buffer.clear(); + self.buffered_rows = 0; + // return batch + Ok(Some(batch)) + } + } +} + #[cfg(test)] mod tests { use super::*; - use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning}; - use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::UInt32Array; #[tokio::test(flavor = "multi_thread")] async fn test_concat_batches() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let output_partitions = coalesce_batches(&schema, partitions, 21).await?; - assert_eq!(1, output_partitions.len()); + let Scenario { schema, batch } = uint32_scenario(); // input is 10 batches x 8 rows (80 rows) + let input = std::iter::repeat(batch).take(10); + // expected output is batches of at least 20 rows (except for the final batch) - let batches = &output_partitions[0]; + let batches = do_coalesce_batches(&schema, input, 21); assert_eq!(4, batches.len()); assert_eq!(24, batches[0].num_rows()); assert_eq!(24, batches[1].num_rows()); @@ -319,54 +359,43 @@ mod tests { Ok(()) } - fn test_schema() -> Arc { - Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) - } - - async fn coalesce_batches( + // Coalesce the batches with a BatchCoalescer function with the given input + // and target batch size returning the resulting batches + fn do_coalesce_batches( schema: &SchemaRef, - input_partitions: Vec>, + input: impl IntoIterator, target_batch_size: usize, - ) -> Result>> { + ) -> Vec { // create physical plan - let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; - let exec = - RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; - let exec: Arc = - Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size)); - - // execute and collect results - let output_partition_count = exec.output_partitioning().partition_count(); - let mut output_partitions = Vec::with_capacity(output_partition_count); - for i in 0..output_partition_count { - // execute this *output* partition and collect all batches - let task_ctx = Arc::new(TaskContext::default()); - let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; - let mut batches = vec![]; - while let Some(result) = stream.next().await { - batches.push(result?); - } - output_partitions.push(batches); + let mut coalescer = BatchCoalescer::new(Arc::clone(schema), target_batch_size); + let mut output_batches: Vec<_> = input + .into_iter() + .filter_map(|batch| coalescer.push_batch(batch).unwrap()) + .collect(); + if let Some(batch) = coalescer.finish().unwrap() { + output_batches.push(batch); } - Ok(output_partitions) + output_batches } - /// Create vector batches - fn create_vec_batches(schema: &Schema, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec + /// Test scenario + #[derive(Debug)] + struct Scenario { + schema: Arc, + batch: RecordBatch, } - /// Create batch - fn create_batch(schema: &Schema) -> RecordBatch { - RecordBatch::try_new( - Arc::new(schema.clone()), + /// a batch of 8 rows of UInt32 + fn uint32_scenario() -> Scenario { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], ) - .unwrap() + .unwrap(); + + Scenario { schema, batch } } }