Skip to content

Commit 74d9d23

Browse files
authored
async method should allow args not only receiver (#4015)
* async method should allow args not only receiver * add changelog md
1 parent 4d033c4 commit 74d9d23

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

newsfragments/4015.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix the bug that an async `#[pymethod]` with receiver can't have any other args.

pyo3-macros-backend/src/method.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::fmt::Display;
22

33
use proc_macro2::{Span, TokenStream};
4-
use quote::{quote, quote_spanned, ToTokens};
4+
use quote::{format_ident, quote, quote_spanned, ToTokens};
55
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
66

77
use crate::utils::Ctx;
@@ -518,17 +518,33 @@ impl<'a> FnSpec<'a> {
518518
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
519519
None => quote!(None),
520520
};
521+
let evaluate_args = || -> (Vec<Ident>, TokenStream) {
522+
let mut arg_names = Vec::with_capacity(args.len());
523+
let mut evaluate_arg = quote! {};
524+
for arg in &args {
525+
let arg_name = format_ident!("arg_{}", arg_names.len());
526+
arg_names.push(arg_name.clone());
527+
evaluate_arg.extend(quote! {
528+
let #arg_name = #arg
529+
});
530+
}
531+
(arg_names, evaluate_arg)
532+
};
521533
let future = match self.tp {
522534
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
535+
let (arg_name, evaluate_arg) = evaluate_args();
523536
quote! {{
537+
#evaluate_arg;
524538
let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
525-
async move { function(&__guard, #(#args),*).await }
539+
async move { function(&__guard, #(#arg_name),*).await }
526540
}}
527541
}
528542
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => {
543+
let (arg_name, evaluate_arg) = evaluate_args();
529544
quote! {{
545+
#evaluate_arg;
530546
let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
531-
async move { function(&mut __guard, #(#args),*).await }
547+
async move { function(&mut __guard, #(#arg_name),*).await }
532548
}}
533549
}
534550
_ => {

tests/test_coroutine.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,39 @@ fn coroutine_panic() {
245245
})
246246
}
247247

248+
#[test]
249+
fn test_async_method_receiver_with_other_args() {
250+
#[pyclass]
251+
struct Value(i32);
252+
#[pymethods]
253+
impl Value {
254+
#[new]
255+
fn new() -> Self {
256+
Self(0)
257+
}
258+
async fn get_value_plus_with(&self, v: i32) -> i32 {
259+
self.0 + v
260+
}
261+
async fn set_value(&mut self, new_value: i32) -> i32 {
262+
self.0 = new_value;
263+
self.0
264+
}
265+
}
266+
267+
Python::with_gil(|gil| {
268+
let test = r#"
269+
import asyncio
270+
271+
v = Value()
272+
assert asyncio.run(v.get_value_plus_with(3)) == 3
273+
assert asyncio.run(v.set_value(10)) == 10
274+
assert asyncio.run(v.get_value_plus_with(1)) == 11
275+
"#;
276+
let locals = [("Value", gil.get_type_bound::<Value>())].into_py_dict_bound(gil);
277+
py_run!(gil, *locals, test);
278+
});
279+
}
280+
248281
#[test]
249282
fn test_async_method_receiver() {
250283
#[pyclass]

0 commit comments

Comments
 (0)