diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 038727daa7d8..b822ec2dafeb 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! CoalesceBatchesExec combines small batches into larger batches for more efficient use of -//! vectorized processing by upstream operators. +//! [`CoalesceBatchesExec`] combines small batches into larger batches. use std::any::Any; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use arrow::array::{AsArray, StringViewBuilder}; use arrow::compute::concat_batches; @@ -41,11 +40,43 @@ use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics}; /// `CoalesceBatchesExec` combines small batches into larger batches for more -/// efficient use of vectorized processing by later operators. The operator -/// works by buffering batches until it collects `target_batch_size` rows. When -/// only a limited number of rows are necessary (specified by the `fetch` -/// parameter), the operator will stop buffering and return the final batch -/// once the number of collected rows reaches the `fetch` value. +/// efficient use of vectorized processing by later operators. +/// +/// The operator buffers batches until it collects `target_batch_size` rows and +/// then emits a single concatenated batch. When only a limited number of rows +/// are necessary (specified by the `fetch` parameter), the operator will stop +/// buffering and return the final batch once the number of collected rows +/// reaches the `fetch` value. 
+/// +/// # Background +/// +/// Generally speaking, larger RecordBatches are more efficient to process than +/// smaller record batches (until the CPU cache is exceeded) because there is +/// fixed processing overhead per batch. This code concatenates multiple small +/// record batches into larger ones to amortize this overhead. +/// +/// ```text +/// ┌────────────────────┐ +/// │ RecordBatch │ +/// │ num_rows = 23 │ +/// └────────────────────┘ ┌────────────────────┐ +/// │ │ +/// ┌────────────────────┐ Coalesce │ │ +/// │ │ Batches │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ +/// │ │ │ RecordBatch │ +/// │ │ │ num_rows = 106 │ +/// └────────────────────┘ │ │ +/// │ │ +/// ┌────────────────────┐ │ │ +/// │ │ │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 33 │ └────────────────────┘ +/// │ │ +/// └────────────────────┘ +/// ``` + #[derive(Debug)] pub struct CoalesceBatchesExec { /// The input plan @@ -166,12 +197,11 @@ impl ExecutionPlan for CoalesceBatchesExec { ) -> Result { Ok(Box::pin(CoalesceBatchesStream { input: self.input.execute(partition, context)?, - schema: self.input.schema(), - target_batch_size: self.target_batch_size, - fetch: self.fetch, - buffer: Vec::new(), - buffered_rows: 0, - total_rows: 0, + coalescer: BatchCoalescer::new( + self.input.schema(), + self.target_batch_size, + self.fetch, + ), is_closed: false, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -196,21 +226,12 @@ impl ExecutionPlan for CoalesceBatchesExec { } } +/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. 
struct CoalesceBatchesStream { /// The input plan input: SendableRecordBatchStream, - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, - /// Maximum number of rows to fetch, `None` means fetching all rows - fetch: Option, - /// Buffered batches - buffer: Vec, - /// Buffered row count - buffered_rows: usize, - /// Total number of rows returned - total_rows: usize, + /// Buffer for combining batches + coalescer: BatchCoalescer, /// Whether the stream has finished returning all of its data or not is_closed: bool, /// Execution metrics @@ -249,84 +270,178 @@ impl CoalesceBatchesStream { let input_batch = self.input.poll_next_unpin(cx); // records time on drop let _timer = cloned_time.timer(); - match input_batch { - Poll::Ready(x) => match x { - Some(Ok(batch)) => { - let batch = gc_string_view_batch(&batch); - - // Handle fetch limit: - if let Some(fetch) = self.fetch { - if self.total_rows + batch.num_rows() >= fetch { - // We have reached the fetch limit. 
- let remaining_rows = fetch - self.total_rows; - debug_assert!(remaining_rows > 0); - + match ready!(input_batch) { + Some(result) => { + let Ok(input_batch) = result else { + return Poll::Ready(Some(result)); // pass back error + }; + // Buffer the batch and either get more input if not enough + // rows yet or output + match self.coalescer.push_batch(input_batch) { + Ok(None) => continue, + res => { + if self.coalescer.limit_reached() { self.is_closed = true; - self.total_rows = fetch; - // Trim the batch and add to buffered batches: - let batch = batch.slice(0, remaining_rows); - self.buffered_rows += batch.num_rows(); - self.buffer.push(batch); - // Combine buffered batches: - let batch = concat_batches(&self.schema, &self.buffer)?; - // Reset the buffer state and return final batch: - self.buffer.clear(); - self.buffered_rows = 0; - return Poll::Ready(Some(Ok(batch))); - } - } - self.total_rows += batch.num_rows(); - - if batch.num_rows() >= self.target_batch_size - && self.buffer.is_empty() - { - return Poll::Ready(Some(Ok(batch))); - } else if batch.num_rows() == 0 { - // discard empty batches - } else { - // add to the buffered batches - self.buffered_rows += batch.num_rows(); - self.buffer.push(batch); - // check to see if we have enough batches yet - if self.buffered_rows >= self.target_batch_size { - // combine the batches and return - let batch = concat_batches(&self.schema, &self.buffer)?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return Poll::Ready(Some(Ok(batch))); } + return Poll::Ready(res.transpose()); } } - None => { - self.is_closed = true; - // we have reached the end of the input stream but there could still - // be buffered batches - if self.buffer.is_empty() { - return Poll::Ready(None); - } else { - // combine the batches and return - let batch = concat_batches(&self.schema, &self.buffer)?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return 
Poll::Ready(Some(Ok(batch))); - } - } - other => return Poll::Ready(other), - }, - Poll::Pending => return Poll::Pending, + } + None => { + self.is_closed = true; + // we have reached the end of the input stream but there could still + // be buffered batches + return match self.coalescer.finish() { + Ok(None) => Poll::Ready(None), + res => Poll::Ready(res.transpose()), + }; + } } } } } impl RecordBatchStream for CoalesceBatchesStream { + fn schema(&self) -> SchemaRef { + self.coalescer.schema() + } +} + +/// Concatenate multiple record batches into larger batches +/// +/// See [`CoalesceBatchesExec`] for more details. +/// +/// Notes: +/// +/// 1. The output rows are in the same order as the input rows +/// +/// 2. The output is a sequence of batches, with all but the last being at least +/// `target_batch_size` rows. +/// +/// 3. Eventually this may also be able to handle other optimizations such as a +/// combined filter/coalesce operation. +#[derive(Debug)] +struct BatchCoalescer { + /// The input schema + schema: SchemaRef, + /// Minimum number of rows for coalesced batches + target_batch_size: usize, + /// Total number of rows returned so far + total_rows: usize, + /// Buffered batches + buffer: Vec, + /// Buffered row count + buffered_rows: usize, + /// Maximum number of rows to fetch, `None` means fetching all rows + fetch: Option, +} + +impl BatchCoalescer { + /// Create a new `BatchCoalescer` + /// + /// # Arguments + /// - `schema` - the schema of the output batches + /// - `target_batch_size` - the minimum number of rows for each + /// output batch (until limit reached) + /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows + fn new(schema: SchemaRef, target_batch_size: usize, fetch: Option) -> Self { + Self { + schema, + target_batch_size, + total_rows: 0, + buffer: vec![], + buffered_rows: 0, + fetch, + } + } + + /// Return the schema of the output batches fn schema(&self) -> SchemaRef { Arc::clone(&self.schema) } + + /// Add a 
batch, returning a batch if the target batch size or limit is reached + fn push_batch(&mut self, batch: RecordBatch) -> Result> { + // discard empty batches + if batch.num_rows() == 0 { + return Ok(None); + } + + // past limit + if self.limit_reached() { + return Ok(None); + } + + let batch = gc_string_view_batch(&batch); + + // Handle fetch limit: + if let Some(fetch) = self.fetch { + if self.total_rows + batch.num_rows() >= fetch { + // We have reached the fetch limit. + let remaining_rows = fetch - self.total_rows; + debug_assert!(remaining_rows > 0); + self.total_rows = fetch; + // Trim the batch and add to buffered batches: + let batch = batch.slice(0, remaining_rows); + self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); + // Combine buffered batches: + let batch = concat_batches(&self.schema, &self.buffer)?; + // Reset the buffer state and return final batch: + self.buffer.clear(); + self.buffered_rows = 0; + return Ok(Some(batch)); + } + } + self.total_rows += batch.num_rows(); + + // batch itself is already big enough and we have no buffered rows so + // return it directly + if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() { + return Ok(Some(batch)); + } + // add to the buffered batches + self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); + // check to see if we have enough batches yet + let batch = if self.buffered_rows >= self.target_batch_size { + // combine the batches and return + let batch = concat_batches(&self.schema, &self.buffer)?; + // reset buffer state + self.buffer.clear(); + self.buffered_rows = 0; + // return batch + Some(batch) + } else { + None + }; + Ok(batch) + } + + /// Finish the coalescing process, returning all buffered data as a final, + /// single batch, if any + fn finish(&mut self) -> Result> { + if self.buffer.is_empty() { + Ok(None) + } else { + // combine the batches and return + let batch = concat_batches(&self.schema, &self.buffer)?; + // reset buffer state + 
self.buffer.clear(); + self.buffered_rows = 0; + // return batch + Ok(Some(batch)) + } + } + + /// returns true if there is a limit and it has been reached + pub fn limit_reached(&self) -> bool { + if let Some(fetch) = self.fetch { + self.total_rows >= fetch + } else { + false + } + } } /// Heuristically compact `StringViewArray`s to reduce memory usage, if needed @@ -400,164 +515,206 @@ fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { #[cfg(test)] mod tests { + use super::*; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::builder::ArrayBuilder; use arrow_array::{StringViewArray, UInt32Array}; + use std::ops::Range; - use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning}; - - use super::*; - - #[tokio::test(flavor = "multi_thread")] - async fn test_concat_batches() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let output_partitions = coalesce_batches(&schema, partitions, 21, None).await?; - assert_eq!(1, output_partitions.len()); - - // input is 10 batches x 8 rows (80 rows) - // expected output is batches of at least 20 rows (except for the final batch) - let batches = &output_partitions[0]; - assert_eq!(4, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(24, batches[2].num_rows()); - assert_eq!(8, batches[3].num_rows()); - - Ok(()) + #[test] + fn test_coalesce() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // expected output is batches of at least 21 rows (except for the final batch) + .with_target_batch_size(21) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run() } - #[tokio::test] - async fn test_concat_batches_with_fetch_larger_than_input_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let 
output_partitions = - coalesce_batches(&schema, partitions, 21, Some(100)).await?; - assert_eq!(1, output_partitions.len()); + #[test] + fn test_coalesce_with_fetch_larger_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 + // expected to behave the same as `test_coalesce` + .with_target_batch_size(21) + .with_fetch(Some(100)) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run(); + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 - // expected to behave the same as `test_concat_batches` - let batches = &output_partitions[0]; - assert_eq!(4, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(24, batches[2].num_rows()); - assert_eq!(8, batches[3].num_rows()); + #[test] + fn test_coalesce_with_fetch_less_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 + .with_target_batch_size(21) + .with_fetch(Some(50)) + .with_expected_output_sizes(vec![24, 24, 2]) + .run(); + } - Ok(()) + #[test] + fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 + .with_target_batch_size(21) + .with_fetch(Some(48)) + .with_expected_output_sizes(vec![24, 24]) + .run(); } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_than_input_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + #[test] + fn test_coalesce_with_fetch_less_target_batch_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // 
input is 10 batches x 8 rows (80 rows) with fetch limit of 10 + .with_target_batch_size(21) + .with_fetch(Some(10)) + .with_expected_output_sizes(vec![10]) + .run(); + } - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(50)).await?; - assert_eq!(1, output_partitions.len()); + #[test] + fn test_coalesce_single_large_batch_over_fetch() { + let large_batch = uint32_batch(0..100); + Test::new() + .with_batch(large_batch) + .with_target_batch_size(20) + .with_fetch(Some(7)) + .with_expected_output_sizes(vec![7]) + .run() + } + + /// Test for [`BatchCoalescer`] + /// + /// Pushes the input batches to the coalescer and verifies that the resulting + /// batches have the expected number of rows and contents. + #[derive(Debug, Clone, Default)] + struct Test { + /// Batches to feed to the coalescer. Tests must have at least one + /// batch + input_batches: Vec, + /// Expected output sizes of the resulting batches + expected_output_sizes: Vec, + /// target batch size + target_batch_size: usize, + /// Fetch (limit) + fetch: Option, + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 - let batches = &output_partitions[0]; - assert_eq!(3, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(2, batches[2].num_rows()); + impl Test { + fn new() -> Self { + Self::default() + } - Ok(()) - } + /// Set the target batch size + fn with_target_batch_size(mut self, target_batch_size: usize) -> Self { + self.target_batch_size = target_batch_size; + self + } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_than_target_and_no_remaining_rows( - ) -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + /// Set the fetch (limit) + fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(48)).await?; - 
assert_eq!(1, output_partitions.len()); + /// Extend the input batches with `batch` + fn with_batch(mut self, batch: RecordBatch) -> Self { + self.input_batches.push(batch); + self + } - // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 - let batches = &output_partitions[0]; - assert_eq!(2, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); + /// Extends the input batches with `batches` + fn with_batches( + mut self, + batches: impl IntoIterator, + ) -> Self { + self.input_batches.extend(batches); + self + } - Ok(()) - } + /// Extends the expected output sizes with `sizes` + fn with_expected_output_sizes( + mut self, + sizes: impl IntoIterator, + ) -> Self { + self.expected_output_sizes.extend(sizes); + self + } - #[tokio::test] - async fn test_concat_batches_with_fetch_less_target_batch_size() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; + /// Runs the test -- see documentation on [`Test`] for details + fn run(self) { + let Self { + input_batches, + target_batch_size, + fetch, + expected_output_sizes, + } = self; - let output_partitions = - coalesce_batches(&schema, partitions, 21, Some(10)).await?; - assert_eq!(1, output_partitions.len()); + let schema = input_batches[0].schema(); - // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 - let batches = &output_partitions[0]; - assert_eq!(1, batches.len()); - assert_eq!(10, batches[0].num_rows()); + // create a single large input batch for output comparison + let single_input_batch = concat_batches(&schema, &input_batches).unwrap(); - Ok(()) - } + let mut coalescer = BatchCoalescer::new(schema, target_batch_size, fetch); - fn test_schema() -> Arc { - Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) - } + let mut output_batches = vec![]; + for batch in input_batches { + if let Some(batch) = coalescer.push_batch(batch).unwrap() { + 
output_batches.push(batch); + } + } + if let Some(batch) = coalescer.finish().unwrap() { + output_batches.push(batch); + } - async fn coalesce_batches( - schema: &SchemaRef, - input_partitions: Vec>, - target_batch_size: usize, - fetch: Option, - ) -> Result>> { - // create physical plan - let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; - let exec = - RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; - let exec: Arc = Arc::new( - CoalesceBatchesExec::new(Arc::new(exec), target_batch_size).with_fetch(fetch), - ); - - // execute and collect results - let output_partition_count = exec.output_partitioning().partition_count(); - let mut output_partitions = Vec::with_capacity(output_partition_count); - for i in 0..output_partition_count { - // execute this *output* partition and collect all batches - let task_ctx = Arc::new(TaskContext::default()); - let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; - let mut batches = vec![]; - while let Some(result) = stream.next().await { - batches.push(result?); + // make sure we got the expected number of output batches and content + let mut starting_idx = 0; + assert_eq!(expected_output_sizes.len(), output_batches.len()); + for (i, (expected_size, batch)) in + expected_output_sizes.iter().zip(output_batches).enumerate() + { + assert_eq!( + *expected_size, + batch.num_rows(), + "Unexpected number of rows in Batch {i}" + ); + + // compare the contents of the batch (using `==` compares the + // underlying memory layout too) + let expected_batch = + single_input_batch.slice(starting_idx, *expected_size); + let batch_strings = batch_to_pretty_strings(&batch); + let expected_batch_strings = batch_to_pretty_strings(&expected_batch); + let batch_strings = batch_strings.lines().collect::>(); + let expected_batch_strings = + expected_batch_strings.lines().collect::>(); + assert_eq!( + expected_batch_strings, batch_strings, + "Unexpected content in Batch {i}:\ + 
\n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}" + ); + starting_idx += *expected_size; } - output_partitions.push(batches); } - Ok(output_partitions) } - /// Create vector batches - fn create_vec_batches(schema: &Schema, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec - } + /// Return a batch of UInt32 with the specified range + fn uint32_batch(range: Range) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); - /// Create batch - fn create_batch(schema: &Schema) -> RecordBatch { RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from_iter_values(range))], ) .unwrap() } @@ -656,4 +813,9 @@ mod tests { } } } + fn batch_to_pretty_strings(batch: &RecordBatch) -> String { + arrow::util::pretty::pretty_format_batches(&[batch.clone()]) + .unwrap() + .to_string() + } }