Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support async iteration of RecordBatchStream #975

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] }
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
arrow = { version = "53", features = ["pyarrow"] }
datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
datafusion-substrait = { version = "43.0.0", optional = true }
Expand All @@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"]

[profile.release]
lto = true
codegen-units = 1
16 changes: 10 additions & 6 deletions python/datafusion/record_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,24 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None:
"""This constructor is typically not called by the end user."""
self.rbs = record_batch_stream

def next(self) -> RecordBatch:
    """Advance the stream and return the next :py:class:`RecordBatch`.

    Thin convenience wrapper over :py:func:`__next__`; like the iterator
    protocol it raises ``StopIteration`` when the stream is exhausted
    (it no longer returns ``None``).
    """
    return next(self)
async def __anext__(self) -> RecordBatch:
    """Asynchronously fetch the next batch from the underlying stream.

    Awaits the wrapped internal stream's ``__anext__`` and wraps the
    result; exhaustion surfaces as ``StopAsyncIteration`` from the
    underlying stream.
    """
    return RecordBatch(await self.rbs.__anext__())

def __next__(self) -> RecordBatch:
    """Synchronously fetch the next batch from the underlying stream.

    Exhaustion surfaces as ``StopIteration`` from the wrapped stream.
    """
    return RecordBatch(next(self.rbs))

def __aiter__(self) -> typing_extensions.Self:
    """Return ``self`` so the stream can be consumed with ``async for``."""
    return self

def __iter__(self) -> typing_extensions.Self:
    """Return ``self`` so the stream can be consumed with a plain ``for`` loop."""
    return self
4 changes: 2 additions & 2 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df):
batch = stream.next()
assert batch is not None
# there should be no more batches: exhaustion is now signaled by
# raising StopIteration rather than returning None
with pytest.raises(StopIteration):
    stream.next()


def test_repartition(df):
Expand Down
51 changes: 41 additions & 10 deletions src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::utils::wait_for_future;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
use tokio::sync::Mutex;

#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
pub struct PyRecordBatch {
Expand All @@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch {

#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
stream: SendableRecordBatchStream,
stream: Arc<Mutex<SendableRecordBatchStream>>,
}

impl PyRecordBatchStream {
    /// Wrap a DataFusion stream so it can be iterated from Python,
    /// both synchronously and asynchronously.
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        Self {
            stream: Arc::new(Mutex::new(stream)),
        }
    }
}

#[pymethods]
impl PyRecordBatchStream {
fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
let result = self.stream.next();
match wait_for_future(py, result) {
None => Ok(None),
Some(Ok(b)) => Ok(Some(b.into())),
Some(Err(e)) => Err(e.into()),
}
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
let stream = self.stream.clone();
wait_for_future(py, next_stream(stream, true))
}

fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> {
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
self.next(py)
}

fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let stream = self.stream.clone();
pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
}

/// Python `__iter__`: the stream object is its own synchronous iterator.
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}

/// Python `__aiter__`: the stream object is its own asynchronous iterator.
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
}

/// Pull one batch from the shared stream.
///
/// On success returns the batch; DataFusion errors are forwarded as Python
/// exceptions. When the stream is exhausted, raises the sentinel exception
/// matching the caller's protocol: `StopIteration` when `sync` is true,
/// `StopAsyncIteration` otherwise.
async fn next_stream(
    stream: Arc<Mutex<SendableRecordBatchStream>>,
    sync: bool,
) -> PyResult<PyRecordBatch> {
    // Serialize polling: only one consumer may drive the stream at a time.
    let mut guard = stream.lock().await;
    match guard.next().await {
        Some(Ok(batch)) => Ok(batch.into()),
        Some(Err(e)) => Err(e.into()),
        None if sync => Err(PyStopIteration::new_err("stream exhausted")),
        None => Err(PyStopAsyncIteration::new_err("stream exhausted")),
    }
}
Loading