Skip to content

Commit

Permalink
deprecate gil-refs in "self" position (#3943)
Browse files Browse the repository at this point in the history
* deprecate gil-refs in "self" position

* feature gate explicit gil-ref tests

* fix MSRV

* adjust bracketing

---------

Co-authored-by: David Hewitt <[email protected]>
  • Loading branch information
Icxolu and davidhewitt authored Mar 9, 2024
1 parent 14d1d2a commit 908e661
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 20 deletions.
72 changes: 61 additions & 11 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,24 @@ impl FnType {
}
FnType::FnClass(span) | FnType::FnNewClass(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
let slf: Ident = syn::Ident::new("_slf_ref", Span::call_site());
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &*(#slf as *const _ as *const *mut _))
.downcast_unchecked::<#pyo3_path::types::PyType>()
),
}
}
FnType::FnModule(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
let slf: Ident = syn::Ident::new("_slf_ref", Span::call_site());
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(#py, &*(#slf as *const _ as *const *mut _))
.downcast_unchecked::<#pyo3_path::types::PyModule>()
),
}
Expand Down Expand Up @@ -519,7 +519,9 @@ impl<'a> FnSpec<'a> {
);
}

let rust_call = |args: Vec<TokenStream>, holders: &mut Vec<TokenStream>| {
let rust_call = |args: Vec<TokenStream>,
self_e: &syn::Ident,
holders: &mut Vec<TokenStream>| {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);

let call = if self.asyncness.is_some() {
Expand All @@ -538,6 +540,7 @@ impl<'a> FnSpec<'a> {
holders.pop().unwrap(); // does not actually use holder created by `self_arg`

quote! {{
#self_e = #pyo3_path::impl_::pymethods::Extractor::<()>::new();
let __guard = #pyo3_path::impl_::coroutine::RefGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
async move { function(&__guard, #(#args),*).await }
}}
Expand All @@ -546,11 +549,25 @@ impl<'a> FnSpec<'a> {
holders.pop().unwrap(); // does not actually use holder created by `self_arg`

quote! {{
#self_e = #pyo3_path::impl_::pymethods::Extractor::<()>::new();
let mut __guard = #pyo3_path::impl_::coroutine::RefMutGuard::<#cls>::new(&#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}}
}
_ => quote! { function(#self_arg #(#args),*) },
_ => {
if self_arg.is_empty() {
quote! {{
#self_e = #pyo3_path::impl_::pymethods::Extractor::<()>::new();
function(#(#args),*)
}}
} else {
quote! { function({
let (self_arg, e) = #pyo3_path::impl_::pymethods::inspect_type(#self_arg);
#self_e = e;
self_arg
}, #(#args),*) }
}
}
};
let mut call = quote! {{
let future = #future;
Expand All @@ -569,10 +586,24 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else if self_arg.is_empty() {
quote! {{
#self_e = #pyo3_path::impl_::pymethods::Extractor::<()>::new();
function(#(#args),*)
}}
} else {
quote! { function(#self_arg #(#args),*) }
quote! {
function({
let (self_arg, e) = #pyo3_path::impl_::pymethods::inspect_type(#self_arg);
#self_e = e;
self_arg
}, #(#args),*)
}
};
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx)
(
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx),
self_arg.span(),
)
};

let func_name = &self.name;
Expand All @@ -582,6 +613,7 @@ impl<'a> FnSpec<'a> {
quote!(#func_name)
};

let self_e = syn::Ident::new("self_e", Span::call_site());
Ok(match self.convention {
CallingConvention::Noargs => {
let mut holders = Vec::new();
Expand All @@ -599,24 +631,32 @@ impl<'a> FnSpec<'a> {
}
})
.collect();
let call = rust_call(args, &mut holders);
let (call, self_arg_span) = rust_call(args, &self_e, &mut holders);
let extract_gil_ref =
quote_spanned! { self_arg_span => #self_e.extract_gil_ref(); };

quote! {
unsafe fn #ident<'py>(
py: #pyo3_path::Python<'py>,
_slf: *mut #pyo3_path::ffi::PyObject,
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
let _slf_ref = &_slf;
let function = #rust_name; // Shadow the function name to avoid #3017
let #self_e;
#( #holders )*
let result = #call;
#extract_gil_ref
result
}
}
}
CallingConvention::Fastcall => {
let mut holders = Vec::new();
let (arg_convert, args) = impl_arg_params(self, cls, true, &mut holders, ctx)?;
let call = rust_call(args, &mut holders);
let (call, self_arg_span) = rust_call(args, &self_e, &mut holders);
let extract_gil_ref =
quote_spanned! { self_arg_span => #self_e.extract_gil_ref(); };

quote! {
unsafe fn #ident<'py>(
py: #pyo3_path::Python<'py>,
Expand All @@ -625,29 +665,38 @@ impl<'a> FnSpec<'a> {
_nargs: #pyo3_path::ffi::Py_ssize_t,
_kwnames: *mut #pyo3_path::ffi::PyObject
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
let _slf_ref = &_slf;
let function = #rust_name; // Shadow the function name to avoid #3017
let #self_e;
#arg_convert
#( #holders )*
let result = #call;
#extract_gil_ref
result
}
}
}
CallingConvention::Varargs => {
let mut holders = Vec::new();
let (arg_convert, args) = impl_arg_params(self, cls, false, &mut holders, ctx)?;
let call = rust_call(args, &mut holders);
let (call, self_arg_span) = rust_call(args, &self_e, &mut holders);
let extract_gil_ref =
quote_spanned! { self_arg_span => #self_e.extract_gil_ref(); };

quote! {
unsafe fn #ident<'py>(
py: #pyo3_path::Python<'py>,
_slf: *mut #pyo3_path::ffi::PyObject,
_args: *mut #pyo3_path::ffi::PyObject,
_kwargs: *mut #pyo3_path::ffi::PyObject
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
let _slf_ref = &_slf;
let function = #rust_name; // Shadow the function name to avoid #3017
let #self_e;
#arg_convert
#( #holders )*
let result = #call;
#extract_gil_ref
result
}
}
Expand All @@ -667,6 +716,7 @@ impl<'a> FnSpec<'a> {
_kwargs: *mut #pyo3_path::ffi::PyObject
) -> #pyo3_path::PyResult<*mut #pyo3_path::ffi::PyObject> {
use #pyo3_path::callback::IntoPyCallbackOutput;
let _slf_ref = &_slf;
let function = #rust_name; // Shadow the function name to avoid #3017
#arg_convert
#( #holders )*
Expand Down
8 changes: 6 additions & 2 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ impl ClassWithFromPyWithMethods {
}

#[classmethod]
#[cfg(feature = "gil-refs")]
fn classmethod_gil_ref(
_cls: &PyType,
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
Expand All @@ -324,16 +325,19 @@ impl ClassWithFromPyWithMethods {
fn test_pymethods_from_py_with() {
Python::with_gil(|py| {
let instance = Py::new(py, ClassWithFromPyWithMethods {}).unwrap();
let has_gil_refs = cfg!(feature = "gil-refs");

py_run!(
py,
instance,
instance
has_gil_refs,
r#"
arg = {1: 1, 2: 3}
assert instance.instance_method(arg) == 2
assert instance.classmethod(arg) == 2
assert instance.classmethod_gil_ref(arg) == 2
if has_gil_refs:
assert instance.classmethod_gil_ref(arg) == 2
assert instance.staticmethod(arg) == 2
"#
);
Expand Down
1 change: 1 addition & 0 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl ClassMethod {

#[classmethod]
/// Test class method.
#[cfg(feature = "gil-refs")]
fn method_gil_ref(cls: &PyType) -> PyResult<String> {
Ok(format!("{}.method()!", cls.qualname()?))
}
Expand Down
8 changes: 7 additions & 1 deletion tests/test_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ fn module_with_functions(m: &Bound<'_, PyModule>) -> PyResult<()> {

#[pyfn(m)]
#[pyo3(pass_module)]
fn with_module(module: &PyModule) -> PyResult<&str> {
fn with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
module.name()
}

Expand Down Expand Up @@ -373,6 +373,7 @@ fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<

#[pyfunction]
#[pyo3(pass_module)]
#[cfg(feature = "gil-refs")]
fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
module.name()
}
Expand Down Expand Up @@ -427,19 +428,22 @@ fn pyfunction_with_module_and_args_kwargs<'py>(

#[pyfunction]
#[pyo3(pass_module)]
#[cfg(feature = "gil-refs")]
fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str> {
module.name()
}

#[pymodule]
fn module_with_functions_with_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?;
#[cfg(feature = "gil-refs")]
m.add_function(wrap_pyfunction!(pyfunction_with_module_gil_ref, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_owned, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?;
m.add_function(wrap_pyfunction!(pyfunction_with_module_and_args_kwargs, m)?)?;
#[cfg(feature = "gil-refs")]
m.add_function(wrap_pyfunction!(
pyfunction_with_pass_module_in_attribute,
m
Expand All @@ -457,6 +461,7 @@ fn test_module_functions_with_module() {
m,
"m.pyfunction_with_module() == 'module_with_functions_with_module'"
);
#[cfg(feature = "gil-refs")]
py_assert!(
py,
m,
Expand Down Expand Up @@ -484,6 +489,7 @@ fn test_module_functions_with_module() {
"m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \
== ('module_with_functions_with_module', 1, 2)"
);
#[cfg(feature = "gil-refs")]
py_assert!(
py,
m,
Expand Down
23 changes: 23 additions & 0 deletions tests/ui/deprecations.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![deny(deprecated)]

use pyo3::prelude::*;
use pyo3::types::{PyString, PyType};

#[pyclass]
struct MyClass;
Expand All @@ -11,10 +12,32 @@ impl MyClass {
fn new() -> Self {
Self
}

#[classmethod]
fn cls_method_gil_ref(_cls: &PyType) {}

#[classmethod]
fn cls_method_bound(_cls: &Bound<'_, PyType>) {}

fn method_gil_ref(_slf: &PyCell<Self>) {}

fn method_bound(_slf: &Bound<'_, Self>) {}
}

fn main() {}

#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
module.name()
}

#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
module.name()
}

#[pyfunction]
fn double(x: usize) -> usize {
x * 2
Expand Down
36 changes: 30 additions & 6 deletions tests/ui/deprecations.stderr
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
error: use of deprecated constant `pyo3::impl_::deprecations::PYMETHODS_NEW_DEPRECATED_FORM`: use `#[new]` instead of `#[__new__]`
--> tests/ui/deprecations.rs:10:7
--> tests/ui/deprecations.rs:11:7
|
10 | #[__new__]
11 | #[__new__]
| ^^^^^^^
|
note: the lint level is defined here
Expand All @@ -10,14 +10,38 @@ note: the lint level is defined here
1 | #![deny(deprecated)]
| ^^^^^^^^^^

error: use of deprecated struct `pyo3::PyCell`: `PyCell` was merged into `Bound`, use that instead; see the migration guide for more info
--> tests/ui/deprecations.rs:22:30
|
22 | fn method_gil_ref(_slf: &PyCell<Self>) {}
| ^^^^^^

error: use of deprecated method `pyo3::methods::Extractor::<T>::extract_gil_ref`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:17:33
|
17 | fn cls_method_gil_ref(_cls: &PyType) {}
| ^

error: use of deprecated method `pyo3::methods::Extractor::<T>::extract_gil_ref`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:22:29
|
22 | fn method_gil_ref(_slf: &PyCell<Self>) {}
| ^

error: use of deprecated method `pyo3::methods::Extractor::<T>::extract_gil_ref`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:37:43
|
37 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
| ^

error: use of deprecated method `pyo3::methods::Extractor::<T>::extract_gil_ref`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:24:19
--> tests/ui/deprecations.rs:47:19
|
24 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
47 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
| ^

error: use of deprecated method `pyo3::methods::Extractor::<T>::extract_gil_ref`: use `&Bound<'_, T>` instead for this function argument
--> tests/ui/deprecations.rs:30:57
--> tests/ui/deprecations.rs:53:57
|
30 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
53 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
| ^

0 comments on commit 908e661

Please sign in to comment.