Skip to content

Commit

Permalink
Merge pull request #3609 from wyfo/async_receiver
Browse files Browse the repository at this point in the history
feat: allow async methods to accept `&self`/`&mut self`
  • Loading branch information
davidhewitt authored Dec 7, 2023
2 parents 4baf023 + f34c70c commit 07726ae
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 22 deletions.
3 changes: 1 addition & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '

As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.

It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*

However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`

## Implicit GIL holding

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3609.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow async methods to accept `&self`/`&mut self`
48 changes: 31 additions & 17 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use std::fmt::Display;

use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::{Ident, Result};
use quote::{quote, quote_spanned, ToTokens};
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::{
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::impl_arg_params,
pyfunction::{
FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute,
},
quotes,
utils::{self, PythonDoc},
};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
Expand Down Expand Up @@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
Expand All @@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
None => quote!(None),
};
call = quote! {{
let future = #call;
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&__guard, #(#args),*).await }
}},
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}},
_ => quote! { function(#self_arg #(#args),*) },
};
let mut call = quote! {{
let future = #future;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
Expand All @@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> {
#call
}};
}
}
call
} else {
quote! { function(#self_arg #(#args),*) }
};
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};

Expand Down
74 changes: 71 additions & 3 deletions src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::future::Future;
use std::{
future::Future,
mem,
ops::{Deref, DerefMut},
};

use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
use crate::{
coroutine::{cancel::ThrowCallback, Coroutine},
pyclass::boolean_struct::False,
types::PyString,
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, Python,
};

pub fn new_coroutine<F, T, E>(
name: &PyString,
Expand All @@ -16,3 +24,63 @@ where
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
// SAFETY: Py<T> can be casted as *const PyCell<T>
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
}

pub struct RefGuard<T: PyClass>(Py<T>);

impl<T: PyClass> RefGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow(obj.py())?);
Ok(RefGuard(owned))
}
}

impl<T: PyClass> Deref for RefGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}

impl<T: PyClass> Drop for RefGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
}
}

pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);

impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow_mut(obj.py())?);
Ok(RefMutGuard(owned))
}
}

impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}

impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &mut *get_ptr(&self.0) }
}
}

impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
}
}
10 changes: 10 additions & 0 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
#[allow(clippy::useless_conversion)]
offset.try_into().expect("offset should fit in Py_ssize_t")
}

#[cfg(feature = "macros")]
pub(crate) fn release_ref(&self) {
self.borrow_checker().release_borrow();
}

#[cfg(feature = "macros")]
pub(crate) fn release_mut(&self) {
self.borrow_checker().release_borrow_mut();
}
}

impl<T: PyClassImpl> PyCell<T> {
Expand Down
53 changes: 53 additions & 0 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,56 @@ fn coroutine_panic() {
py_run!(gil, panic, &handle_windows(test));
})
}

#[test]
fn test_async_method_receiver() {
#[pyclass]
struct Counter(usize);
#[pymethods]
impl Counter {
#[new]
fn new() -> Self {
Self(0)
}
async fn get(&self) -> usize {
self.0
}
async fn incr(&mut self) -> usize {
self.0 += 1;
self.0
}
}
Python::with_gil(|gil| {
let test = r#"
import asyncio
obj = Counter()
coro1 = obj.get()
coro2 = obj.get()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro1) == 0
coro2.close()
coro3 = obj.incr()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
try:
obj.get() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro3) == 1
"#;
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
py_run!(gil, *locals, test);
})
}

0 comments on commit 07726ae

Please sign in to comment.