diff --git a/datafusion/core/src/physical_plan/sorts/cascade.rs b/datafusion/core/src/physical_plan/sorts/cascade.rs index ccbf8dc8d137b..3ac0ae7d20630 100644 --- a/datafusion/core/src/physical_plan/sorts/cascade.rs +++ b/datafusion/core/src/physical_plan/sorts/cascade.rs @@ -13,6 +13,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; use futures::Stream; +use std::collections::VecDeque; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -51,10 +52,11 @@ impl SortPreservingCascadeStream { // Refer to YieldedCursorStream for where the concat would happen (TODO). let streams = Arc::new(parking_lot::Mutex::new(streams)); - let max_streams_per_merge = 2; // TODO: change this to 10, once we have tested with 2 (to force more leaf nodes) - let mut divided_streams: Vec> = - Vec::with_capacity(stream_count / max_streams_per_merge + 1); + let max_streams_per_merge = 2; // TODO: change this to 10, once we have tested with 2 (to force more cascade levels) + let mut divided_streams: VecDeque> = + VecDeque::with_capacity(stream_count / max_streams_per_merge + 1); + // build leaves for stream_offset in (0..stream_count).step_by(max_streams_per_merge) { let limit = std::cmp::min(stream_offset + max_streams_per_merge, stream_count); @@ -63,7 +65,7 @@ impl SortPreservingCascadeStream { let streams = OffsetCursorStream::new(Arc::clone(&streams), stream_offset, limit); - divided_streams.push(Box::pin(SortPreservingMergeStream::new( + divided_streams.push_back(Box::pin(SortPreservingMergeStream::new( Box::new(streams), metrics.clone(), batch_size, @@ -72,21 +74,36 @@ impl SortPreservingCascadeStream { ))); } - let next_level: CursorStream = - Box::new(YieldedCursorStream::new(divided_streams)); + // build rest of tree + let mut next_level: VecDeque> = + VecDeque::with_capacity(divided_streams.len() / max_streams_per_merge + 1); + while divided_streams.len() > 1 || !next_level.is_empty() { + let fan_in: Vec> = divided_streams + .drain(0..std::cmp::min(max_streams_per_merge, divided_streams.len())) + .collect(); - let root: MergeStream = Box::pin(SortPreservingMergeStream::new( - next_level, - metrics.clone(), - batch_size, - fetch, - reservation, - )); + next_level.push_back(Box::pin(SortPreservingMergeStream::new( + Box::new(YieldedCursorStream::new(fan_in)), + metrics.clone(), + batch_size, + if divided_streams.is_empty() && next_level.is_empty() { + fetch + } else { + None + }, // fetch, the LIMIT, is applied to the final merge + reservation.new_empty(), + ))); + // in order to maintain sort-preserving streams, don't mix the merge tree levels. + if divided_streams.is_empty() { + divided_streams = next_level.drain(..).collect(); + } + } - let cascade = root; Self { aborted: false, - cascade, + cascade: divided_streams + .remove(0) + .expect("must have a root merge stream"), schema, metrics, }