Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support Bound for classmethod and pass_module #3831

Merged
merged 4 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ This is the equivalent of the Python decorator `@classmethod`.
#[pymethods]
impl MyClass {
#[classmethod]
fn cls_method(cls: &PyType) -> PyResult<i32> {
fn cls_method(cls: &Bound<'_, PyType>) -> PyResult<i32> {
Ok(10)
}
}
Expand Down Expand Up @@ -719,10 +719,10 @@ To create a constructor which takes a positional class argument, you can combine
impl BaseClass {
#[new]
#[classmethod]
fn py_new<'p>(cls: &'p PyType, py: Python<'p>) -> PyResult<Self> {
fn py_new(cls: &Bound<'_, PyType>) -> PyResult<Self> {
// Get an abstract attribute (presumably) declared on a subclass of this class.
let subclass_attr = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.to_object(py)))
let subclass_attr: Bound<'_, PyAny> = cls.getattr("a_class_attr")?;
Ok(Self(subclass_attr.unbind()))
}
}
```
Expand Down Expand Up @@ -928,7 +928,7 @@ impl MyClass {
// similarly for classmethod arguments, use $cls
#[classmethod]
#[pyo3(text_signature = "($cls, e, f)")]
fn my_class_method(cls: &PyType, e: i32, f: i32) -> i32 {
fn my_class_method(cls: &Bound<'_, PyType>, e: i32, f: i32) -> i32 {
e + f
}
#[staticmethod]
Expand Down
3 changes: 2 additions & 1 deletion guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ The `#[pyo3]` attribute can be used to modify properties of the generated Python

```rust
use pyo3::prelude::*;
use pyo3::types::PyString;

#[pyfunction]
#[pyo3(pass_module)]
fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> {
fn pyfunction_with_module<'py>(module: &Bound<'py, PyModule>) -> PyResult<Bound<'py, PyString>> {
module.name()
}

Expand Down
14 changes: 11 additions & 3 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,21 @@ impl FnType {
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(_pyo3::types::PyType::from_type_ptr(#py, #slf.cast())),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyType>()
),
}
}
FnType::FnModule(span) => {
let py = syn::Ident::new("py", Span::call_site());
let slf: Ident = syn::Ident::new("_slf", Span::call_site());
quote_spanned! { *span =>
#[allow(clippy::useless_conversion)]
::std::convert::Into::into(py.from_borrowed_ptr::<_pyo3::types::PyModule>(_slf)),
::std::convert::Into::into(
_pyo3::impl_::pymethods::BoundRef::ref_from_ptr(#py, &#slf.cast())
.downcast_unchecked::<_pyo3::types::PyModule>()
),
}
}
}
Expand Down Expand Up @@ -409,7 +417,7 @@ impl<'a> FnSpec<'a> {
// will error on incorrect type.
Some(syn::FnArg::Typed(first_arg)) => first_arg.ty.span(),
Some(syn::FnArg::Receiver(_)) | None => bail_spanned!(
sig.paren_token.span.join() => "Expected `&PyType` or `Py<PyType>` as the first argument to `#[classmethod]`"
sig.paren_token.span.join() => "Expected `&Bound<PyType>` or `Py<PyType>` as the first argument to `#[classmethod]`"
),
};
FnType::FnClass(span)
Expand Down
20 changes: 20 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ struct AssertingBaseClass;

#[pymethods]
impl AssertingBaseClass {
#[new]
#[classmethod]
fn new(cls: &Bound<'_, PyType>, expected_type: Bound<'_, PyType>) -> PyResult<Self> {
if !cls.is(&expected_type) {
return Err(PyValueError::new_err(format!(
"{:?} != {:?}",
cls, expected_type
)));
}
Ok(Self)
}
}

#[pyclass(subclass)]
#[derive(Clone, Debug)]
struct AssertingBaseClassGilRef;

#[pymethods]
impl AssertingBaseClassGilRef {
#[new]
#[classmethod]
fn new(cls: &PyType, expected_type: &PyType) -> PyResult<Self> {
Expand All @@ -65,6 +84,7 @@ pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<AssertingBaseClassGilRef>()?;
m.add_class::<ClassWithoutConstructor>()?;
Ok(())
}
11 changes: 11 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def test_new_classmethod():
_ = AssertingSubClass(expected_type=str)


def test_new_classmethod_gil_ref():
class AssertingSubClass(pyclasses.AssertingBaseClassGilRef):
pass

# The `AssertingBaseClass` constructor errors if it is not passed the
# relevant subclass.
_ = AssertingSubClass(expected_type=AssertingSubClass)
with pytest.raises(ValueError):
_ = AssertingSubClass(expected_type=str)


class ClassWithoutConstructorPy:
def __new__(cls):
raise TypeError("No constructor defined")
Expand Down
53 changes: 52 additions & 1 deletion src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::internal_tricks::extract_c_string;
use crate::types::{any::PyAnyMethods, PyModule, PyType};
use crate::{
ffi, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit, Python,
ffi, Bound, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, PyTraverseError, PyVisit,
Python,
};
use std::borrow::Cow;
use std::ffi::CStr;
Expand Down Expand Up @@ -466,3 +468,52 @@ pub trait AsyncIterResultOptionKind {
}

impl<Value, Error> AsyncIterResultOptionKind for Result<Option<Value>, Error> {}

/// Used in `#[classmethod]` to pass the class object to the method
/// and also in `#[pyfunction(pass_module)]`.
///
/// This is a wrapper to avoid implementing `From<Bound>` for GIL Refs.
///
/// Once the GIL Ref API is fully removed, it should be possible to simplify
/// this to just `&'a Bound<'py, T>` and `From` implementations.
pub struct BoundRef<'a, 'py, T>(pub &'a Bound<'py, T>);

impl<'a, 'py> BoundRef<'a, 'py, PyAny> {
pub unsafe fn ref_from_ptr(py: Python<'py>, ptr: &'a *mut ffi::PyObject) -> Self {
BoundRef(Bound::ref_from_ptr(py, ptr))
}

pub unsafe fn downcast_unchecked<T>(self) -> BoundRef<'a, 'py, T> {
BoundRef(self.0.downcast_unchecked::<T>())
}
}

// GIL Ref implementations for &'a T ran into trouble with orphan rules,
// so explicit implementations are used instead for the two relevant types.
impl<'a> From<BoundRef<'a, 'a, PyType>> for &'a PyType {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyType>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a> From<BoundRef<'a, 'a, PyModule>> for &'a PyModule {
#[inline]
fn from(bound: BoundRef<'a, 'a, PyModule>) -> Self {
bound.0.as_gil_ref()
}
}

impl<'a, 'py, T> From<BoundRef<'a, 'py, T>> for &'a Bound<'py, T> {
#[inline]
fn from(bound: BoundRef<'a, 'py, T>) -> Self {
bound.0
}
}

impl<T> From<BoundRef<'_, '_, T>> for Py<T> {
#[inline]
fn from(bound: BoundRef<'_, '_, T>) -> Self {
bound.0.clone().unbind()
}
}
18 changes: 18 additions & 0 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,24 @@ impl<'py> Bound<'py, PyAny> {
) -> PyResult<Self> {
Py::from_owned_ptr_or_err(py, ptr).map(|obj| Self(py, ManuallyDrop::new(obj)))
}

/// This slightly strange method is used to obtain `&Bound<PyAny>` from a pointer in macro code
/// where we need to constrain the lifetime `'a` safely.
///
/// Note that `'py` is required to outlive `'a` implicitly by the nature of the fact that
/// `&'a Bound<'py>` means that `Bound<'py>` exists for at least the lifetime `'a`.
///
/// # Safety
/// - `ptr` must be a valid pointer to a Python object for the lifetime `'a`. The `ptr` can
/// be either a borrowed reference or an owned reference, it does not matter, as this is
/// just `&Bound` there will never be any ownership transfer.
#[inline]
pub(crate) unsafe fn ref_from_ptr<'a>(
_py: Python<'py>,
ptr: &'a *mut ffi::PyObject,
) -> &'a Self {
&*(ptr as *const *mut ffi::PyObject).cast::<Bound<'py, PyAny>>()
}
}

impl<'py, T> Bound<'py, T>
Expand Down
4 changes: 2 additions & 2 deletions src/tests/hygiene/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down Expand Up @@ -770,7 +770,7 @@ impl Dummy {
#[staticmethod]
fn staticmethod() {}
#[classmethod]
fn clsmethod(_: &crate::types::PyType) {}
fn clsmethod(_: &crate::Bound<'_, crate::types::PyType>) {}
#[pyo3(signature = (*_args, **_kwds))]
fn __call__(
&self,
Expand Down
16 changes: 14 additions & 2 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ fn panic_unsendable_child() {
test_unsendable::<UnsendableChild>().unwrap();
}

fn get_length(obj: &PyAny) -> PyResult<usize> {
fn get_length(obj: &Bound<'_, PyAny>) -> PyResult<usize> {
let length = obj.len()?;

Ok(length)
Expand All @@ -299,7 +299,18 @@ impl ClassWithFromPyWithMethods {
argument
}
#[classmethod]
fn classmethod(_cls: &PyType, #[pyo3(from_py_with = "PyAny::len")] argument: usize) -> usize {
fn classmethod(
_cls: &Bound<'_, PyType>,
#[pyo3(from_py_with = "Bound::<'_, PyAny>::len")] argument: usize,
) -> usize {
argument
}

#[classmethod]
fn classmethod_gil_ref(
_cls: &PyType,
#[pyo3(from_py_with = "PyAny::len")] argument: usize,
) -> usize {
argument
}

Expand All @@ -322,6 +333,7 @@ fn test_pymethods_from_py_with() {

assert instance.instance_method(arg) == 2
assert instance.classmethod(arg) == 2
assert instance.classmethod_gil_ref(arg) == 2
assert instance.staticmethod(arg) == 2
"#
);
Expand Down
20 changes: 15 additions & 5 deletions tests/test_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ impl ClassMethod {

#[classmethod]
/// Test class method.
fn method(cls: &PyType) -> PyResult<String> {
fn method(cls: &Bound<'_, PyType>) -> PyResult<String> {
Ok(format!("{}.method()!", cls.as_gil_ref().qualname()?))
}

#[classmethod]
/// Test class method.
fn method_gil_ref(cls: &PyType) -> PyResult<String> {
Ok(format!("{}.method()!", cls.qualname()?))
}

Expand Down Expand Up @@ -108,8 +114,12 @@ struct ClassMethodWithArgs {}
#[pymethods]
impl ClassMethodWithArgs {
#[classmethod]
fn method(cls: &PyType, input: &PyString) -> PyResult<String> {
Ok(format!("{}.method({})", cls.qualname()?, input))
fn method(cls: &Bound<'_, PyType>, input: &PyString) -> PyResult<String> {
Ok(format!(
"{}.method({})",
cls.as_gil_ref().qualname()?,
input
))
}
}

Expand Down Expand Up @@ -915,7 +925,7 @@ impl r#RawIdents {
}

#[classmethod]
pub fn r#class_method(_: &PyType, r#type: PyObject) -> PyObject {
pub fn r#class_method(_: &Bound<'_, PyType>, r#type: PyObject) -> PyObject {
r#type
}

Expand Down Expand Up @@ -1082,7 +1092,7 @@ issue_1506!(

#[classmethod]
fn issue_1506_class(
_cls: &PyType,
_cls: &Bound<'_, PyType>,
_py: Python<'_>,
_arg: &PyAny,
_args: &PyTuple,
Expand Down
Loading
Loading