Skip to content

Commit

Permalink
feat: add #[pyo3(allow_threads)] to release the GIL in (async) func…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
wyfo committed Nov 30, 2023
1 parent c8fa064 commit 32c2050
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 8 deletions.
1 change: 1 addition & 0 deletions newsfragments/3610.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(allow_threads);
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
Expand Down
28 changes: 26 additions & 2 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
use crate::{attributes, quotes};
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use quote::{format_ident, ToTokens};
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
Expand Down Expand Up @@ -239,6 +239,7 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub deprecations: Deprecations,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

pub fn get_return_info(output: &syn::ReturnType) -> syn::Type {
Expand Down Expand Up @@ -282,6 +283,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
allow_threads,
..
} = options;

Expand Down Expand Up @@ -329,6 +331,7 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
deprecations,
allow_threads,
})
}

Expand Down Expand Up @@ -472,6 +475,7 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>| {
let allow_threads = self.allow_threads.is_some();
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand Down Expand Up @@ -504,6 +508,7 @@ impl<'a> FnSpec<'a> {
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
Expand All @@ -515,6 +520,25 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else if allow_threads {
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
(quote!(), quote!())
} else {
(quote!(__self), quote! { let __self = #self_arg; })
};
let arg_names: Vec<Ident> = (0..args.len())
.map(|i| format_ident!("__arg{}", i))
.collect();
let arg_decls: Vec<TokenStream> = args
.into_iter()
.zip(&arg_names)
.map(|(arg, name)| quote! { let #name = #arg; })
.collect();
quote! {{
#self_arg_decl
#(#arg_decls)*
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
}}
} else {
quote! { function(#self_arg #(#args),*) }
};
Expand Down
12 changes: 10 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::allow_threads)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
AllowThreads(attributes::kw::allow_threads),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -131,7 +134,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::allow_threads) {
input.parse().map(PyFunctionOption::AllowThreads)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -171,6 +176,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand Down Expand Up @@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
allow_threads,
pass_module,
name,
signature,
Expand Down Expand Up @@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
signature,
output: ty,
text_signature,
allow_threads,
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
deprecations: Deprecations::new(),
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
11 changes: 10 additions & 1 deletion src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -49,6 +50,7 @@ impl Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
where
Expand All @@ -65,6 +67,7 @@ impl Coroutine {
name,
qualname_prefix,
throw_callback,
allow_threads,
future: Some(Box::pin(wrap)),
waker: None,
}
Expand Down Expand Up @@ -98,7 +101,13 @@ impl Coroutine {
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
let poll = || {
if self.allow_threads {
py.allow_threads(|| future_rs.as_mut().poll(&mut Context::from_waker(&waker)))
} else {
future_rs.as_mut().poll(&mut Context::from_waker(&waker))
}
};
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
Expand Down
42 changes: 41 additions & 1 deletion src/gil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ fn decrement_gil_count() {
#[cfg(test)]
mod tests {
use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
use crate::{ffi, gil, PyObject, Python, ToPyObject};
use crate::{ffi, gil, py_run, wrap_pyfunction, PyObject, Python, ToPyObject};
#[cfg(not(target_arch = "wasm32"))]
use parking_lot::{const_mutex, Condvar, Mutex};
use std::ptr::NonNull;
Expand Down Expand Up @@ -925,4 +925,44 @@ mod tests {
POOL.update_counts(py);
})
}

#[test]
fn allow_threads_fn() {
#[crate::pyfunction(allow_threads, crate = "crate")]
fn without_gil() {
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
}
Python::with_gil(|gil| {
let without_gil = wrap_pyfunction!(without_gil, gil).unwrap();
py_run!(gil, without_gil, "without_gil()");
})
}

#[test]
fn allow_threads_async_fn() {
#[crate::pyfunction(allow_threads, crate = "crate")]
async fn without_gil() {
use std::task::Poll;
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
let mut ready = false;
futures::future::poll_fn(|cx| {
if ready {
return Poll::Ready(());
}
ready = true;
cx.waker().wake_by_ref();
Poll::Pending
})
.await;
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
}
Python::with_gil(|gil| {
let without_gil = wrap_pyfunction!(without_gil, gil).unwrap();
py_run!(
gil,
without_gil,
"import asyncio; asyncio.run(without_gil())"
);
})
}
}
9 changes: 8 additions & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@ pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Coroutine
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
Coroutine::new(
Some(name.into()),
qualname_prefix,
throw_callback,
allow_threads,
future,
)
}

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
Expand Down
2 changes: 1 addition & 1 deletion tests/ui/invalid_pyfunction_signatures.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x`
13 | #[pyo3(signature = (x))]
| ^

error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`
error: expected one of: `allow_threads`, `name`, `pass_module`, `signature`, `text_signature`, `crate`
--> tests/ui/invalid_pyfunction_signatures.rs:18:14
|
18 | #[pyfunction(x)]
Expand Down

0 comments on commit 32c2050

Please sign in to comment.