From 4f682b1277261dafbb40297529dc3381b0d1c32b Mon Sep 17 00:00:00 2001 From: Mason Hall Date: Sat, 26 Oct 2024 23:14:31 -0400 Subject: [PATCH 1/3] Squashed commit of the following: commit 28b2bf145086894f9997bfd6f10734361ec67523 Author: Mason Hall Date: Fri Oct 18 15:49:31 2024 -0400 Cleaned up prefetch and added a test commit 3d7e018bd2f61f707e8e189cf8649fe9c067dbe7 Author: Mason Hall Date: Fri Oct 18 13:32:22 2024 -0400 prefetch working --- parquet/src/arrow/arrow_reader/mod.rs | 3 + parquet/src/arrow/async_reader/mod.rs | 242 +++++++++++++++++++++++--- 2 files changed, 224 insertions(+), 21 deletions(-) diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs index d3709c03e99a..e30ffd444db9 100644 --- a/parquet/src/arrow/arrow_reader/mod.rs +++ b/parquet/src/arrow/arrow_reader/mod.rs @@ -72,6 +72,8 @@ pub struct ArrowReaderBuilder { pub(crate) limit: Option, pub(crate) offset: Option, + + pub(crate) prefetch: bool, } impl ArrowReaderBuilder { @@ -88,6 +90,7 @@ impl ArrowReaderBuilder { selection: None, limit: None, offset: None, + prefetch: false, } } diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 029567d4ef98..b27dd3605e92 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -413,6 +413,17 @@ impl ParquetRecordBatchStreamBuilder { Ok(Some(Sbbf::new(&bitset))) } + /// For async readers, load data for the next row group while decoding the + /// current row group. + /// + /// Defaults to `false` + pub fn with_prefetch(self, prefetch: bool) -> Self { + Self { + prefetch, + ..self + } + } + /// Build a new [`ParquetRecordBatchStream`] pub fn build(self) -> Result> { let num_row_groups = self.metadata.row_groups().len(); @@ -461,6 +472,8 @@ impl ParquetRecordBatchStreamBuilder { row_groups, projection: self.projection, selection: self.selection, + prefetch_row_groups: self.prefetch, + next_reader: None, schema, reader: Some(reader), state: StreamState::Init, @@ -591,6 +604,8 @@ enum StreamState { Init, /// Decoding a batch Decoding(ParquetRecordBatchReader), + /// Decoding a batch while fetching another row group + Prefetch(ParquetRecordBatchReader, BoxFuture<'static, ReadResult>), /// Reading data from input Reading(BoxFuture<'static, ReadResult>), /// Error @@ -602,6 +617,7 @@ impl std::fmt::Debug for StreamState { match self { StreamState::Init => write!(f, "StreamState::Init"), StreamState::Decoding(_) => write!(f, "StreamState::Decoding"), + StreamState::Prefetch(..) => write!(f, "StreamState::Prefetch"), StreamState::Reading(_) => write!(f, "StreamState::Reading"), StreamState::Error => write!(f, "StreamState::Error"), } @@ -623,6 +639,11 @@ pub struct ParquetRecordBatchStream { selection: Option, + prefetch_row_groups: bool, + + /// The next row group to decode if we are prefetching. + next_reader: Option, + /// This is an option so it can be moved into a future reader: Option>, @@ -651,6 +672,31 @@ impl ParquetRecordBatchStream { } } +impl ParquetRecordBatchStream +where + T: AsyncFileReader + 'static, +{ + /// Returns a future for reading row group `row_group_idx`. + /// + /// Note: this function should only be called in [`StreamState::Init`] and + /// [`StreamState::Decoding`] as this takes [`ParquetRecordBatchStream::reader`] + /// and panics if it does not exist. + fn read_row_group(&mut self, row_group_idx: usize) -> BoxFuture<'static, ReadResult> { + let reader = self.reader.take().expect("lost reader"); + + let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; + + let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); + + reader.read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ).boxed() + } +} + impl Stream for ParquetRecordBatchStream where T: AsyncFileReader + Unpin + Send + 'static, @@ -660,36 +706,99 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { match &mut self.state { - StreamState::Decoding(batch_reader) => match batch_reader.next() { - Some(Ok(batch)) => { - return Poll::Ready(Some(Ok(batch))); + StreamState::Decoding(batch_reader) => { + let res: Self::Item; + match batch_reader.next() { + Some(Ok(batch)) => { + res = Ok(batch); + } + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + } + None => { + self.state = StreamState::Init; + continue + } } - Some(Err(e)) => { - self.state = StreamState::Error; - return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + + if !self.prefetch_row_groups + || self.row_groups.is_empty() + || self.next_reader.is_some() + { + return Poll::Ready(Some(res)) + } + + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + + let row_group_idx = self + .row_groups + .pop_front() + .unwrap(); // already checked that row_groups is not empty + + let fut = self.read_row_group(row_group_idx); + + if let StreamState::Decoding(batch_reader) = old_state { + self.state = StreamState::Prefetch(batch_reader, fut); + return Poll::Ready(Some(res)) + } else { + unreachable!() + } + }, + StreamState::Prefetch(batch_reader, f) => { + let mut noop_cx = Context::from_waker( + futures::task::noop_waker_ref() + ); + match f.poll_unpin(&mut noop_cx) { + Poll::Pending => (), + Poll::Ready(Ok((reader_factory, maybe_reader))) => { + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + if let StreamState::Prefetch(batch_reader, _) = old_state { + self.state = StreamState::Decoding(batch_reader); + } else { + unreachable!() + } + self.reader = Some(reader_factory); + self.next_reader = maybe_reader; + continue + }, + Poll::Ready(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(e))) + }, + } + + match batch_reader.next() { + Some(Ok(batch)) => { + return Poll::Ready(Some(Ok(batch))) + } + Some(Err(e)) => { + self.state = StreamState::Error; + return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); + } + None => { + let old_state = std::mem::replace(&mut self.state, StreamState::Init); + if let StreamState::Prefetch(_, f) = old_state { + self.state = StreamState::Reading(f); + return Poll::Pending + } else { + unreachable!() + } + } } - None => self.state = StreamState::Init, }, StreamState::Init => { + if let Some(batch_reader) = self.next_reader.take() { + self.state = StreamState::Decoding(batch_reader); + continue + } + let row_group_idx = match self.row_groups.pop_front() { Some(idx) => idx, None => return Poll::Ready(None), }; - let reader = self.reader.take().expect("lost reader"); - - let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize; - - let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); - - let fut = reader - .read_row_group( - row_group_idx, - selection, - self.projection.clone(), - self.batch_size, - ) - .boxed(); + let fut = self.read_row_group(row_group_idx); self.state = StreamState::Reading(fut) } @@ -2037,4 +2146,95 @@ mod tests { // Should only have made 3 requests assert_eq!(requests.lock().unwrap().len(), 3); } + + #[tokio::test] + async fn test_reader_prefetch() {let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]); + let b = StringArray::from_iter_values(["1", "2", "3", "4", "5", "6"]); + let c = Int32Array::from_iter(0..6); + let data = RecordBatch::try_from_iter([ + ("a", Arc::new(a) as ArrayRef), + ("b", Arc::new(b) as ArrayRef), + ("c", Arc::new(c) as ArrayRef), + ]) + .unwrap(); + + let mut buf = Vec::with_capacity(1024); + let props = WriterProperties::builder() + .set_max_row_group_size(3) + .build(); + let mut writer = ArrowWriter::try_new(&mut buf, data.schema(), Some(props)).unwrap(); + writer.write(&data).unwrap(); + writer.close().unwrap(); + + let data: Bytes = buf.into(); + let metadata = ParquetMetaDataReader::new() + .parse_and_finish(&data) + .unwrap(); + + assert_eq!(metadata.num_row_groups(), 2); + + let test = TestReader { + data, + metadata: Arc::new(metadata), + requests: Default::default(), + }; + + let mut stream = ParquetRecordBatchStreamBuilder::new(test.clone()) + .await + .unwrap() + .with_batch_size(1) + .with_limit(5) + .with_prefetch(true) + .build() + .unwrap(); + + let batch1 = stream.try_next() + .await + .unwrap() + .unwrap(); + // Each batch should only have one row + assert_eq!(batch1.num_rows(), 1); + // Make sure we are pre-fetching + assert!(matches!(stream.state, StreamState::Prefetch(..))); + + let batch2 = stream.try_next() + .await + .unwrap() + .unwrap(); + assert_eq!(batch2.num_rows(), 1); + // We should no longer be prefetching... + assert!(matches!(stream.state, StreamState::Decoding(..))); + // because we already fetched the next row group. + assert!(stream.next_reader.is_some()); + + let batch3 = stream.try_next() + .await + .unwrap() + .unwrap(); + assert_eq!(batch3.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_some()); + + let batch4 = stream.try_next() + .await + .unwrap() + .unwrap(); + assert_eq!(batch4.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_none()); + + let batch5 = stream.try_next() + .await + .unwrap() + .unwrap(); + assert_eq!(batch5.num_rows(), 1); + assert!(matches!(stream.state, StreamState::Decoding(..))); + assert!(stream.next_reader.is_none()); + + let batch6 = stream.try_next() + .await + .unwrap(); + assert!(batch6.is_none()); + assert!(stream.next_reader.is_none()); + } } From d9c71b91232afc35b2ce7d2a6f4f85f198d01876 Mon Sep 17 00:00:00 2001 From: Mason Hall Date: Wed, 30 Oct 2024 17:35:43 -0400 Subject: [PATCH 2/3] Fix future not waking due to noop_cx --- parquet/src/arrow/async_reader/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index b27dd3605e92..1a961a7087c1 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -780,7 +780,7 @@ where let old_state = std::mem::replace(&mut self.state, StreamState::Init); if let StreamState::Prefetch(_, f) = old_state { self.state = StreamState::Reading(f); - return Poll::Pending + continue } else { unreachable!() } From 9eb6570a9c7aafa62b414d51da1aff49cafcc2b0 Mon Sep 17 00:00:00 2001 From: Mason Hall Date: Wed, 30 Oct 2024 19:59:29 -0400 Subject: [PATCH 3/3] rustfmt and clippy fixes --- parquet/src/arrow/async_reader/mod.rs | 95 ++++++++++----------------- 1 file changed, 34 insertions(+), 61 deletions(-) diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index 1a961a7087c1..9868494462da 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -418,10 +418,7 @@ impl ParquetRecordBatchStreamBuilder { /// /// Defaults to `false` pub fn with_prefetch(self, prefetch: bool) -> Self { - Self { - prefetch, - ..self - } + Self { prefetch, ..self } } /// Build a new [`ParquetRecordBatchStream`] @@ -688,12 +685,14 @@ where let selection = self.selection.as_mut().map(|s| s.split_off(row_count)); - reader.read_row_group( - row_group_idx, - selection, - self.projection.clone(), - self.batch_size, - ).boxed() + reader + .read_row_group( + row_group_idx, + selection, + self.projection.clone(), + self.batch_size, + ) + .boxed() } } @@ -707,48 +706,40 @@ where loop { match &mut self.state { StreamState::Decoding(batch_reader) => { - let res: Self::Item; - match batch_reader.next() { - Some(Ok(batch)) => { - res = Ok(batch); - } + let res: Self::Item = match batch_reader.next() { + Some(Ok(batch)) => Ok(batch), Some(Err(e)) => { self.state = StreamState::Error; return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); } None => { self.state = StreamState::Init; - continue + continue; } - } + }; if !self.prefetch_row_groups || self.row_groups.is_empty() || self.next_reader.is_some() { - return Poll::Ready(Some(res)) + return Poll::Ready(Some(res)); } let old_state = std::mem::replace(&mut self.state, StreamState::Init); - let row_group_idx = self - .row_groups - .pop_front() - .unwrap(); // already checked that row_groups is not empty + let row_group_idx = self.row_groups.pop_front().unwrap(); // already checked that row_groups is not empty let fut = self.read_row_group(row_group_idx); if let StreamState::Decoding(batch_reader) = old_state { self.state = StreamState::Prefetch(batch_reader, fut); - return Poll::Ready(Some(res)) + return Poll::Ready(Some(res)); } else { unreachable!() } - }, + } StreamState::Prefetch(batch_reader, f) => { - let mut noop_cx = Context::from_waker( - futures::task::noop_waker_ref() - ); + let mut noop_cx = Context::from_waker(futures::task::noop_waker_ref()); match f.poll_unpin(&mut noop_cx) { Poll::Pending => (), Poll::Ready(Ok((reader_factory, maybe_reader))) => { @@ -760,18 +751,16 @@ where } self.reader = Some(reader_factory); self.next_reader = maybe_reader; - continue - }, + continue; + } Poll::Ready(Err(e)) => { self.state = StreamState::Error; - return Poll::Ready(Some(Err(e))) - }, + return Poll::Ready(Some(Err(e))); + } } match batch_reader.next() { - Some(Ok(batch)) => { - return Poll::Ready(Some(Ok(batch))) - } + Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))), Some(Err(e)) => { self.state = StreamState::Error; return Poll::Ready(Some(Err(ParquetError::ArrowError(e.to_string())))); @@ -780,17 +769,17 @@ where let old_state = std::mem::replace(&mut self.state, StreamState::Init); if let StreamState::Prefetch(_, f) = old_state { self.state = StreamState::Reading(f); - continue + continue; } else { unreachable!() } } } - }, + } StreamState::Init => { if let Some(batch_reader) = self.next_reader.take() { self.state = StreamState::Decoding(batch_reader); - continue + continue; } let row_group_idx = match self.row_groups.pop_front() { @@ -2148,7 +2137,8 @@ mod tests { } #[tokio::test] - async fn test_reader_prefetch() {let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]); + async fn test_reader_prefetch() { + let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]); let b = StringArray::from_iter_values(["1", "2", "3", "4", "5", "6"]); let c = Int32Array::from_iter(0..6); let data = RecordBatch::try_from_iter([ @@ -2188,52 +2178,35 @@ mod tests { .build() .unwrap(); - let batch1 = stream.try_next() - .await - .unwrap() - .unwrap(); + let batch1 = stream.try_next().await.unwrap().unwrap(); // Each batch should only have one row assert_eq!(batch1.num_rows(), 1); // Make sure we are pre-fetching assert!(matches!(stream.state, StreamState::Prefetch(..))); - let batch2 = stream.try_next() - .await - .unwrap() - .unwrap(); + let batch2 = stream.try_next().await.unwrap().unwrap(); assert_eq!(batch2.num_rows(), 1); // We should no longer be prefetching... assert!(matches!(stream.state, StreamState::Decoding(..))); // because we already fetched the next row group. assert!(stream.next_reader.is_some()); - let batch3 = stream.try_next() - .await - .unwrap() - .unwrap(); + let batch3 = stream.try_next().await.unwrap().unwrap(); assert_eq!(batch3.num_rows(), 1); assert!(matches!(stream.state, StreamState::Decoding(..))); assert!(stream.next_reader.is_some()); - let batch4 = stream.try_next() - .await - .unwrap() - .unwrap(); + let batch4 = stream.try_next().await.unwrap().unwrap(); assert_eq!(batch4.num_rows(), 1); assert!(matches!(stream.state, StreamState::Decoding(..))); assert!(stream.next_reader.is_none()); - let batch5 = stream.try_next() - .await - .unwrap() - .unwrap(); + let batch5 = stream.try_next().await.unwrap().unwrap(); assert_eq!(batch5.num_rows(), 1); assert!(matches!(stream.state, StreamState::Decoding(..))); assert!(stream.next_reader.is_none()); - let batch6 = stream.try_next() - .await - .unwrap(); + let batch6 = stream.try_next().await.unwrap(); assert!(batch6.is_none()); assert!(stream.next_reader.is_none()); }