diff --git a/Cargo.lock b/Cargo.lock index d1f291be..352771cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1303,6 +1303,7 @@ dependencies = [ "prost", "prost-types", "pyo3", + "pyo3-async-runtimes", "pyo3-build-config", "tokio", "url", @@ -2672,6 +2673,19 @@ dependencies = [ "unindent", ] +[[package]] +name = "pyo3-async-runtimes" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2529f0be73ffd2be0cc43c013a640796558aa12d7ca0aab5cc14f375b4733031" +dependencies = [ + "futures", + "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + [[package]] name = "pyo3-build-config" version = "0.22.6" diff --git a/Cargo.toml b/Cargo.toml index 703fc5a2..d2884468 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ substrait = ["dep:datafusion-substrait"] [dependencies] tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]} arrow = { version = "53", features = ["pyarrow"] } datafusion = { version = "43.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } datafusion-substrait = { version = "43.0.0", optional = true } @@ -60,4 +61,4 @@ crate-type = ["cdylib", "rlib"] [profile.release] lto = true -codegen-units = 1 \ No newline at end of file +codegen-units = 1 diff --git a/python/datafusion/record_batch.py b/python/datafusion/record_batch.py index 44936f7d..75e58998 100644 --- a/python/datafusion/record_batch.py +++ b/python/datafusion/record_batch.py @@ -57,20 +57,24 @@ def __init__(self, record_batch_stream: df_internal.RecordBatchStream) -> None: """This constructor is typically not called by the end user.""" self.rbs = record_batch_stream - def next(self) -> RecordBatch | None: + def next(self) -> RecordBatch: """See :py:func:`__next__` for the iterator function.""" - try: - next_batch = next(self) - except StopIteration: - 
return None + return next(self) - return next_batch + async def __anext__(self) -> RecordBatch: + """Async iterator function.""" + next_batch = await self.rbs.__anext__() + return RecordBatch(next_batch) def __next__(self) -> RecordBatch: """Iterator function.""" next_batch = next(self.rbs) return RecordBatch(next_batch) + def __aiter__(self) -> typing_extensions.Self: + """Async iterator function.""" + return self + def __iter__(self) -> typing_extensions.Self: """Iterator function.""" return self diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index b82f95e3..e3bd1b2a 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -761,8 +761,8 @@ def test_execution_plan(aggregate_df): batch = stream.next() assert batch is not None # there should be no more batches - batch = stream.next() - assert batch is None + with pytest.raises(StopIteration): + stream.next() def test_repartition(df): diff --git a/src/record_batch.rs b/src/record_batch.rs index 427807f2..eacdb586 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::utils::wait_for_future; use datafusion::arrow::pyarrow::ToPyArrow; use datafusion::arrow::record_batch::RecordBatch; use datafusion::physical_plan::SendableRecordBatchStream; use futures::StreamExt; +use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration}; use pyo3::prelude::*; use pyo3::{pyclass, pymethods, PyObject, PyResult, Python}; +use tokio::sync::Mutex; #[pyclass(name = "RecordBatch", module = "datafusion", subclass)] pub struct PyRecordBatch { @@ -43,31 +47,58 @@ impl From<RecordBatch> for PyRecordBatch { #[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)] pub struct PyRecordBatchStream { - stream: SendableRecordBatchStream, + stream: Arc<Mutex<SendableRecordBatchStream>>, } impl PyRecordBatchStream { pub fn new(stream: SendableRecordBatchStream) -> Self { - Self { stream } + Self { + stream: Arc::new(Mutex::new(stream)), + } } } #[pymethods] impl PyRecordBatchStream { - fn next(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> { - let result = self.stream.next(); - match wait_for_future(py, result) { - None => Ok(None), - Some(Ok(b)) => Ok(Some(b.into())), - Some(Err(e)) => Err(e.into()), - } + fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> { + let stream = self.stream.clone(); + wait_for_future(py, next_stream(stream, true)) } - fn __next__(&mut self, py: Python) -> PyResult<Option<PyRecordBatch>> { + fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> { self.next(py) } + fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> { + let stream = self.stream.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false)) + } + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } + + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } +} + +async fn next_stream( + stream: Arc<Mutex<SendableRecordBatchStream>>, + sync: bool, +) -> PyResult<PyRecordBatch> { + let mut stream = stream.lock().await; + match stream.next().await { + Some(Ok(batch)) => Ok(batch.into()), + Some(Err(e)) => Err(e.into()), + None => { + // Depending on whether the iteration is sync or not, we raise either a + //
StopIteration or a StopAsyncIteration + if sync { + Err(PyStopIteration::new_err("stream exhausted")) + } else { + Err(PyStopAsyncIteration::new_err("stream exhausted")) + } + } + } }