// Patch summary (review notes placed ahead of the first `diff --git` header, outside every hunk).
//
// Purpose: adds an opt-in "prefetch" mode to the async Parquet record-batch stream so that the
// read of the next row group is started while batches of the current row group are still being
// handed out.
//
// parquet/src/arrow/arrow_reader/mod.rs
//   * `ArrowReaderBuilder` gains a `pub(crate) prefetch: bool` field, initialised to `false`.
//
// parquet/src/arrow/async_reader/mod.rs
//   * `ParquetRecordBatchStreamBuilder::with_prefetch(bool)` sets that flag; `build()` copies it
//     into the new `prefetch_row_groups` field and initialises the new `next_reader` field to `None`.
//   * New `StreamState::Prefetch(reader, future)` variant: the batch reader currently being
//     drained plus the in-flight read of another row group. The `Debug` impl prints it as
//     "StreamState::Prefetch".
//   * `read_row_group(row_group_idx)` is the read setup that used to live inline in the `Init`
//     arm: it takes `self.reader` (panicking with "lost reader" if it is absent), splits that row
//     group's row count off the front of `self.selection`, and returns the boxed read future.
//   * `poll_next` state machine after this patch:
//       - `Init`: if `next_reader` holds a prefetched reader, move straight to `Decoding` with it;
//         otherwise pop the next row group index (the stream ends when none is left), start the
//         read and move to `Reading`.
//       - `Decoding`: a decode error moves to `Error`; an exhausted reader moves back to `Init`.
//         When a batch is produced it is returned as-is if prefetch is disabled, no row group is
//         queued, or a prefetched reader is already waiting. Otherwise the next row group index is
//         popped, its read is started, the state becomes `Prefetch`, and the batch is returned.
//       - `Prefetch`: the read future is polled first. `Ready(Ok(..))` stores the reader factory
//         back into `self.reader`, stores the optional new reader in `next_reader`, and returns to
//         `Decoding`. `Ready(Err(e))` moves to `Error` and returns `e`. `Pending` falls through to
//         the current reader: its next batch is returned, a decode error moves to `Error`, and an
//         exhausted reader hands the still-pending future over to `Reading`.
//   * `test_reader_prefetch` writes 6 rows with a max row group size of 3 (asserting 2 row
//     groups), then reads with batch size 1, limit 5 and prefetch enabled, asserting the stream
//     state and `next_reader` after each of the five batches and that the sixth poll yields `None`.
//
// NOTE(review): in `Prefetch` the read future is polled with a context built from
// `futures::task::noop_waker_ref()`, so that poll does not register the caller's waker. Progress
// of the read in this state presumably relies on the consumer polling again for the next batch -
// confirm this is intended.
// NOTE(review): a read error observed in `Prefetch` is returned immediately, before any batches
// still buffered in the current row group's reader are yielded - confirm this ordering is wanted.
// NOTE(review): this copy of the patch looks damaged by extraction: the original line breaks are
// collapsed and text between angle brackets appears to be missing (e.g. `Option,`, `Result>`,
// `Poll>`). It presumably cannot be applied as-is - regenerate it from the commit before use.
diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index d3709c03e99a..e30ffd444db9 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -72,6 +72,8 @@ pub struct ArrowReaderBuilder { pub(crate) limit: Option, pub(crate) offset: Option, + + pub(crate) prefetch: bool, } impl ArrowReaderBuilder { @@ -88,6 +90,7 @@ impl ArrowReaderBuilder { selection: None, limit: None, offset: None, + prefetch: false, } } diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 029567d4ef98..9868494462da 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -413,6 +413,14 @@ impl ParquetRecordBatchStreamBuilder { Ok(Some(Sbbf::new(&bitset))) } + /// For async readers, load data for the next row group while decoding the + /// current row group. + /// + /// Defaults to `false` + pub fn with_prefetch(self, prefetch: bool) -> Self { + Self { prefetch, ..self } + } + /// Build a new [`ParquetRecordBatchStream`] pub fn build(self) -> Result> { let num_row_groups = self.metadata.row_groups().len(); @@ -461,6 +469,8 @@ impl ParquetRecordBatchStreamBuilder { row_groups, projection: self.projection, selection: self.selection, + prefetch_row_groups: self.prefetch, + next_reader: None, schema, reader: Some(reader), state: StreamState::Init, @@ -591,6 +601,8 @@ enum StreamState { Init, /// Decoding a batch Decoding(ParquetRecordBatchReader), + /// Decoding a batch while fetching another row group + Prefetch(ParquetRecordBatchReader, BoxFuture<'static, ReadResult>), /// Reading data from input Reading(BoxFuture<'static, ReadResult>), /// Error @@ -602,6 +614,7 @@ impl std::fmt::Debug for StreamState { match self { StreamState::Init => write!(f, "StreamState::Init"), StreamState::Decoding(_) => write!(f, "StreamState::Decoding"), + StreamState::Prefetch(..) 
=> write!(f, "StreamState::Prefetch"), StreamState::Reading(_) => write!(f, "StreamState::Reading"), StreamState::Error => write!(f, "StreamState::Error"), } @@ -623,6 +636,11 @@ pub struct ParquetRecordBatchStream { selection: Option, + prefetch_row_groups: bool, + + /// The next row group to decode if we are prefetching. + next_reader: Option, + /// This is an option so it can be moved into a future reader: Option>, @@ -651,6 +669,33 @@ impl ParquetRecordBatchStream { } } +impl ParquetRecordBatchStream +where + T: AsyncFileReader + 'static, +{ + /// Returns a future for reading row group `row_group_idx`. + /// + /// Note: this function should only be called in [`StreamState::Init`] and + /// [`StreamState::Decoding`] as this takes [`ParquetRecordBatchStream::reader`] + /// and panics if it does not exist. + fn read_row_group(&mut self, row_group_idx: usize) -> BoxFuture<'static, ReadResult> { + let reader = self.reader.take().expect("lost reader"); + + let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; + + let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); + + reader + .read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ) + .boxed() + } +} + impl Stream for ParquetRecordBatchStream where T: AsyncFileReader + Unpin + Send + 'static, @@ -660,36 +705,89 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - StreamState::Decoding(batch_reader) => match batch_reader.next() { - Some(Ok(batch)) => { - return Poll::Ready(Some(Ok(batch))); + StreamState::Decoding(batch_reader) => { + let res: Self::Item = match batch_reader.next() { + Some(Ok(batch)) => Ok(batch), + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + } + None => { + self.state = StreamState::Init; + continue; + } + }; + + if !self.prefetch_row_groups + || self.row_groups.is_empty() + 
|| self.next_reader.is_some() + { + return Poll::Ready(Some(res)); } - Some(Err(e)) => { - self.state = StreamState::Error; - return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + + let row_group_idx = self.row_groups.pop_front().unwrap(); // already checked that row_groups is not empty + + let fut = self.read_row_group(row_group_idx); + + if let StreamState::Decoding(batch_reader) = old_state { + self.state = StreamState::Prefetch(batch_reader, fut); + return Poll::Ready(Some(res)); + } else { + unreachable!() } - None => self.state = StreamState::Init, - }, + } + StreamState::Prefetch(batch_reader, f) => { + let mut noop_cx = Context::from_waker(futures::task::noop_waker_ref()); + match f.poll_unpin(&mut noop_cx) { + Poll::Pending => (), + Poll::Ready(Ok((reader_factory, maybe_reader))) => { + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + if let StreamState::Prefetch(batch_reader, _) = old_state { + self.state = StreamState::Decoding(batch_reader); + } else { + unreachable!() + } + self.reader = Some(reader_factory); + self.next_reader = maybe_reader; + continue; + } + Poll::Ready(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))); + } + } + + match batch_reader.next() { + Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))), + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + } + None => { + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + if let StreamState::Prefetch(_, f) = old_state { + self.state = StreamState::Reading(f); + continue; + } else { + unreachable!() + } + } + } + } StreamState::Init => { + if let Some(batch_reader) = self.next_reader.take() { + self.state = StreamState::Decoding(batch_reader); + continue; + } + let row_group_idx = match self.row_groups.pop_front() { Some(idx) => idx, None => 
return Poll::Ready(None), }; - let reader = self.reader.take().expect("lost reader"); - - let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; - - let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); - - let fut = reader - .read_row_group( - row_group_idx, - selection, - self.projection.clone(), - self.batch_size, - ) - .boxed(); + let fut = self.read_row_group(row_group_idx); self.state = StreamState::Reading(fut) } @@ -2037,4 +2135,79 @@ mod tests { // Should only have made 3 requests assert_eq!(requests.lock().unwrap().len(), 3); } + + #[tokio::test] + async fn test_reader_prefetch() { + let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]); + let b = StringArray::from_iter_values(["1", "2", "3", "4", "5", "6"]); + let c = Int32Array::from_iter(0..6); + let data = RecordBatch::try_from_iter([ + ("a", Arc::new(a) as ArrayRef), + ("b", Arc::new(b) as ArrayRef), + ("c", Arc::new(c) as ArrayRef), + ]) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let props = WriterProperties::builder() + .set_max_row_group_size(3) + .build(); + let mut writer = ArrowWriter::try_new(&mut buf, data.schema(), Some(props)).unwrap(); + writer.write(&data).unwrap(); + writer.close().unwrap(); + + let data: Bytes = buf.into(); + let metadata = ParquetMetaDataReader::new() + .parse_and_finish(&data) + .unwrap(); + + assert_eq!(metadata.num_row_groups(), 2); + + let test = TestReader { + data, + metadata: Arc::new(metadata), + requests: Default::default(), + }; + + let mut stream = ParquetRecordBatchStreamBuilder::new(test.clone()) + .await + .unwrap() + .with_batch_size(1) + .with_limit(5) + .with_prefetch(true) + .build() + .unwrap(); + + let batch1 = stream.try_next().await.unwrap().unwrap(); + // Each batch should only have one row + assert_eq!(batch1.num_rows(), 1); + // Make sure we are pre-fetching + assert!(matches!(stream.state, StreamState::Prefetch(..))); + + let batch2 = 
stream.try_next().await.unwrap().unwrap(); + assert_eq!(batch2.num_rows(), 1); + // We should no longer be prefetching... + assert!(matches!(stream.state, StreamState::Decoding(..))); + // because we already fetched the next row group. + assert!(stream.next_reader.is_some()); + + let batch3 = stream.try_next().await.unwrap().unwrap(); + assert_eq!(batch3.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_some()); + + let batch4 = stream.try_next().await.unwrap().unwrap(); + assert_eq!(batch4.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_none()); + + let batch5 = stream.try_next().await.unwrap().unwrap(); + assert_eq!(batch5.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_none()); + + let batch6 = stream.try_next().await.unwrap(); + assert!(batch6.is_none()); + assert!(stream.next_reader.is_none()); + } }