Skip to content

Commit

Permalink
feat: add PyFuture to await Python awaitables
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 30, 2023
1 parent a7679ec commit 55a0e82
Show file tree
Hide file tree
Showing 11 changed files with 605 additions and 131 deletions.
1 change: 1 addition & 0 deletions guide/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- [Python exceptions](exception.md)
- [Calling Python from Rust](python_from_rust.md)
- [Using `async` and `await`](async-await.md)
- [Awaiting Python awaitables](async-await/pyfuture.md)
- [GIL, mutability and object types](types.md)
- [Parallelism](parallelism.md)
- [Debugging](debugging.md)
Expand Down
51 changes: 51 additions & 0 deletions guide/src/async-await/pyfuture.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Awaiting Python awaitables

Python awaitable can be awaited on Rust side using [`PyFuture`]({{#PYO3_DOCS_URL}}/pyo3/types/struct.PyFuture.html).

```rust
# #![allow(dead_code)]
use pyo3::{prelude::*, types::PyFuture};

#[pyfunction]
async fn wrap_awaitable(awaitable: PyObject) -> PyResult<PyObject> {
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
future.await
}
```

`PyFuture::from_object` construct a `PyFuture` from a Python awaitable object, by calling its `__await__` method (or `__iter__` for generator-based coroutine).

## Restrictions

`PyFuture` can only be awaited in the context of a PyO3 coroutine. Otherwise, it panics.

```rust
# #![allow(dead_code)]
use pyo3::{prelude::*, types::PyFuture};

#[pyfunction]
fn block_on(awaitable: PyObject) -> PyResult<PyObject> {
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
futures::executor::block_on(future) // ERROR: PyFuture must be awaited in coroutine context
}
```

`PyFuture` must be the only Rust future awaited; it means that it's forbidden to `select!` a `Pyfuture`. Otherwise, it panics.

```rust
# #![allow(dead_code)]
use std::future;
use futures::FutureExt;
use pyo3::{prelude::*, types::PyFuture};

#[pyfunction]
async fn select(awaitable: PyObject) -> PyResult<PyObject> {
let future: Py<PyFuture> = Python::with_gil(|gil| Py::from_object(gil, awaitable))?;
futures::select_biased! {
_ = future::pending::<()>().fuse() => unreachable!(),
res = future.fuse() => res, // ERROR: Python awaitable mixed with Rust future
}
}
```

These restrictions exist because awaiting a `PyFuture` strongly binds it to the enclosing coroutine. The coroutine will then delegate its `send`/`throw`/`close` methods to the awaited `PyFuture`. If it was awaited in a `select!`, `Coroutine::send` would no able to know if the value passed would have to be delegated to the `Pyfuture` or not.
6 changes: 5 additions & 1 deletion pyo3-ffi/src/abstract_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ extern "C" {
pub fn PyIter_Next(arg1: *mut PyObject) -> *mut PyObject;
#[cfg(all(not(PyPy), Py_3_10))]
#[cfg_attr(PyPy, link_name = "PyPyIter_Send")]
pub fn PyIter_Send(iter: *mut PyObject, arg: *mut PyObject, presult: *mut *mut PyObject);
pub fn PyIter_Send(
iter: *mut PyObject,
arg: *mut PyObject,
presult: *mut *mut PyObject,
) -> c_int;

#[cfg_attr(PyPy, link_name = "PyPyNumber_Check")]
pub fn PyNumber_Check(o: *mut PyObject) -> c_int;
Expand Down
97 changes: 49 additions & 48 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
//! Python coroutine implementation, used notably when wrapping `async fn`
//! with `#[pyfunction]`/`#[pymethods]`.
use std::task::Waker;
use std::{
future::Future,
panic,
pin::Pin,
sync::Arc,
task::{Context, Poll},
task::{Context, Poll, Waker},
};

use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::AsyncioWaker,
coroutine::waker::CoroutineWaker,
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
pyclass::IterNextOutput,
types::{PyIterator, PyString},
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
types::PyString,
IntoPy, Py, PyErr, PyObject, PyResult, Python,
};

mod asyncio;
pub(crate) mod cancel;
mod waker;
pub(crate) mod waker;

use crate::coroutine::cancel::ThrowCallback;
use crate::panic::PanicException;
Expand All @@ -36,7 +36,7 @@ pub struct Coroutine {
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
waker: Option<Arc<CoroutineWaker>>,
}

impl Coroutine {
Expand Down Expand Up @@ -73,33 +73,37 @@ impl Coroutine {
}
}

fn poll(
fn poll_inner(
&mut self,
py: Python<'_>,
throw: Option<PyObject>,
mut sent_result: Option<Result<PyObject, PyObject>>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
// raise if the coroutine has already been run to completion
let future_rs = match self.future {
Some(ref mut fut) => fut,
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
};
// reraise thrown exception it
match (throw, &self.throw_callback) {
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
(Some(exc), None) => {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
// if the future is not pending on a Python awaitable,
// execute throw callback or complete on close
if !matches!(self.waker, Some(ref w) if w.yielded_from_awaitable(py)) {
match (sent_result, &self.throw_callback) {
(res @ Some(Ok(_)), _) => sent_result = res,
(Some(Err(err)), Some(cb)) => {
cb.throw(err.as_ref(py));
sent_result = Some(Ok(py.None().into()));
}
(Some(Err(err)), None) => return Err(PyErr::from_value(err.as_ref(py))),
(None, _) => return Ok(IterNextOutput::Return(py.None().into())),
}
_ => {}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
waker.reset();
waker.reset(sent_result);
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
self.waker = Some(Arc::new(CoroutineWaker::new(sent_result)));
}
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// poll the Rust future and forward its results if ready; otherwise, yield from waker
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || {
if self.allow_threads {
Expand All @@ -109,29 +113,27 @@ impl Coroutine {
}
};
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
return Ok(IterNextOutput::Return(res?));
}
Err(err) => {
self.close();
return Err(PanicException::from_panic_payload(err));
}
_ => {}
Err(err) => Err(PanicException::from_panic_payload(err)),
Ok(Poll::Ready(res)) => Ok(IterNextOutput::Return(res?)),
Ok(Poll::Pending) => match self.waker.as_ref().unwrap().yield_(py) {
Ok(to_yield) => Ok(IterNextOutput::Yield(to_yield)),
Err(err) => Err(err),
},
}
// otherwise, initialize the waker `asyncio.Future`
if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
// `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
// and will yield itself if its result has not been set in polling above
if let Some(future) = PyIterator::from_object(future).unwrap().next() {
// future has not been leaked into Python for now, and Rust code can only call
// `set_result(None)` in `Wake` implementation, so it's safe to unwrap
return Ok(IterNextOutput::Yield(future.unwrap().into()));
}
}

fn poll(
&mut self,
py: Python<'_>,
sent_result: Option<Result<PyObject, PyObject>>,
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
let result = self.poll_inner(py, sent_result);
if matches!(result, Ok(IterNextOutput::Return(_)) | Err(_)) {
// the Rust future is dropped, and the field set to `None`
// to indicate the coroutine has been run to completion
drop(self.future.take());
}
// if waker has been waken during future polling, this is roughly equivalent to
// `await asyncio.sleep(0)`, so just yield `None`.
Ok(IterNextOutput::Yield(py.None().into()))
result
}
}

Expand Down Expand Up @@ -163,25 +165,24 @@ impl Coroutine {
}
}

fn send(&mut self, py: Python<'_>, _value: &PyAny) -> PyResult<PyObject> {
iter_result(self.poll(py, None)?)
fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(Ok(value)))?)
}

fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
iter_result(self.poll(py, Some(exc))?)
iter_result(self.poll(py, Some(Err(exc)))?)
}

fn close(&mut self) {
// the Rust future is dropped, and the field set to `None`
// to indicate the coroutine has been run to completion
drop(self.future.take());
fn close(&mut self, py: Python<'_>) -> PyResult<()> {
self.poll(py, None)?;
Ok(())
}

fn __await__(self_: Py<Self>) -> Py<Self> {
self_
}

fn __next__(&mut self, py: Python<'_>) -> PyResult<IterNextOutput<PyObject, PyObject>> {
self.poll(py, None)
self.poll(py, Some(Ok(py.None().into())))
}
}
90 changes: 90 additions & 0 deletions src/coroutine/asyncio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//! Coroutine implementation compatible with asyncio.
use crate::sync::GILOnceCell;
use crate::types::{PyCFunction, PyIterator};
use crate::{intern, wrap_pyfunction, IntoPy, Py, PyAny, PyObject, PyResult, Python};
use pyo3_macros::pyfunction;

/// `asyncio.get_running_loop`
fn get_running_loop(py: Python<'_>) -> PyResult<&PyAny> {
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
let import = || -> PyResult<_> {
let module = py.import("asyncio")?;
Ok(module.getattr("get_running_loop")?.into())
};
GET_RUNNING_LOOP
.get_or_try_init(py, import)?
.as_ref(py)
.call0()
}

/// Asyncio-compatible coroutine waker.
///
/// Polling a Rust future yields an `asyncio.Future`, whose `set_result` method is called
/// when `Waker::wake` is called.
pub(super) struct AsyncioWaker {
event_loop: PyObject,
future: PyObject,
}

impl AsyncioWaker {
pub(super) fn new(py: Python<'_>) -> PyResult<Self> {
let event_loop = get_running_loop(py)?.into_py(py);
let future = event_loop.call_method0(py, "create_future")?;
Ok(Self { event_loop, future })
}

pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> {
let __await__;
// `asyncio.Future` must be awaited; in normal case, it implements `__iter__ = __await__`,
// but `create_future` may have been overriden
let mut iter = match PyIterator::from_object(self.future.as_ref(py)) {
Ok(iter) => iter,
Err(_) => {
__await__ = self.future.call_method0(py, intern!(py, "__await__"))?;
PyIterator::from_object(__await__.as_ref(py))?
}
};
// future has not been waken (because `yield_waken` would have been called
// otherwise), so it is expected to yield itself
Ok(iter.next().expect("future didn't yield")?.into_py(py))
}

pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> {
Ok(py.None().into())
}

pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> {
static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
let release_waiter = RELEASE_WAITER
.get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?;
// `Future.set_result` must be called in event loop thread,
// so it requires `call_soon_threadsafe`
let call_soon_threadsafe = self.event_loop.call_method1(
py,
intern!(py, "call_soon_threadsafe"),
(release_waiter, self.future.as_ref(py)),
);
if let Err(err) = call_soon_threadsafe {
// `call_soon_threadsafe` will raise if the event loop is closed;
// instead of catching an unspecific `RuntimeError`, check directly if it's closed.
let is_closed = self.event_loop.call_method0(py, "is_closed")?;
if !is_closed.extract(py)? {
return Err(err);
}
}
Ok(())
}
}

/// Call `future.set_result` if the future is not done.
///
/// Future can be cancelled by the event loop before being waken.
/// See <https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5>
#[pyfunction(crate = "crate")]
fn release_waiter(future: &PyAny) -> PyResult<()> {
let done = future.call_method0(intern!(future.py(), "done"))?;
if !done.extract::<bool>()? {
future.call_method1(intern!(future.py(), "set_result"), (future.py().None(),))?;
}
Ok(())
}
Loading

0 comments on commit 55a0e82

Please sign in to comment.