Skip to content

Commit

Permalink
Merge pull request #856 from althonos/patch-iter
Browse files Browse the repository at this point in the history
Allow passing non-mutable references to self in PyIterProtocol
  • Loading branch information
kngwyu authored Apr 19, 2020
2 parents c897155 + 6ac1b05 commit a58a1cf
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 16 deletions.
7 changes: 5 additions & 2 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,10 @@ It includes two methods `__iter__` and `__next__`:
* `fn __iter__(slf: PyRefMut<Self>) -> PyResult<impl IntoPy<PyObject>>`
* `fn __next__(slf: PyRefMut<Self>) -> PyResult<Option<impl IntoPy<PyObject>>>`

Returning `Ok(None)` from `__next__` indicates that that there are no further items.
Returning `Ok(None)` from `__next__` indicates that that there are no further items.
These two methods can be take either `PyRef<Self>` or `PyRefMut<Self>` as their
first argument, so that mutable borrow can be avoided if needed.


Example:

Expand All @@ -823,7 +826,7 @@ struct MyIterator {

#[pyproto]
impl PyIterProtocol for MyIterator {
fn __iter__(mut slf: PyRefMut<Self>) -> PyResult<Py<MyIterator>> {
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<MyIterator>> {
Ok(slf.into())
}
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyObject>> {
Expand Down
6 changes: 4 additions & 2 deletions pyo3-derive-backend/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,15 @@ pub const ITER: Proto = Proto {
name: "Iter",
py_methods: &[],
methods: &[
MethodProto::Unary {
MethodProto::UnaryS {
name: "__iter__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterIterProtocol",
},
MethodProto::Unary {
MethodProto::UnaryS {
name: "__next__",
arg: "Receiver",
pyres: true,
proto: "pyo3::class::iter::PyIterNextProtocol",
},
Expand Down
59 changes: 59 additions & 0 deletions pyo3-derive-backend/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ pub enum MethodProto {
pyres: bool,
proto: &'static str,
},
UnaryS {
name: &'static str,
arg: &'static str,
pyres: bool,
proto: &'static str,
},
Binary {
name: &'static str,
arg: &'static str,
Expand Down Expand Up @@ -60,6 +66,7 @@ impl MethodProto {
match *self {
MethodProto::Free { ref name, .. } => name,
MethodProto::Unary { ref name, .. } => name,
MethodProto::UnaryS { ref name, .. } => name,
MethodProto::Binary { ref name, .. } => name,
MethodProto::BinaryS { ref name, .. } => name,
MethodProto::Ternary { ref name, .. } => name,
Expand Down Expand Up @@ -114,6 +121,58 @@ pub(crate) fn impl_method_proto(
}
}
}
MethodProto::UnaryS {
pyres, proto, arg, ..
} => {
let p: syn::Path = syn::parse_str(proto).unwrap();
let (ty, succ) = get_res_success(ty);

let slf_name = syn::Ident::new(arg, Span::call_site());
let mut slf_ty = get_arg_ty(sig, 0);

// update the type if no lifetime was given:
// PyRef<Self> --> PyRef<'p, Self>
if let syn::Type::Path(ref mut path) = slf_ty {
if let syn::PathArguments::AngleBracketed(ref mut args) =
path.path.segments[0].arguments
{
if let syn::GenericArgument::Lifetime(_) = args.args[0] {
} else {
let lt = syn::parse_quote! {'p};
args.args.insert(0, lt);
}
}
}

let tmp: syn::ItemFn = syn::parse_quote! {
fn test(&self) -> <#cls as #p<'p>>::Result {}
};
sig.output = tmp.sig.output;
modify_self_ty(sig);

if let syn::FnArg::Typed(ref mut arg) = sig.inputs[0] {
arg.ty = Box::new(syn::parse_quote! {
<#cls as #p<'p>>::#slf_name
});
}

if pyres {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Success = #succ;
type Result = #ty;
}
}
} else {
quote! {
impl<'p> #p<'p> for #cls {
type #slf_name = #slf_ty;
type Result = #ty;
}
}
}
}
MethodProto::Binary {
name,
arg,
Expand Down
13 changes: 8 additions & 5 deletions src/class/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
//! Trait and support implementation for implementing iterators
use crate::callback::IntoPyCallbackOutput;
use crate::derive_utils::TryFromPyCell;
use crate::err::PyResult;
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, PyRefMut, Python};
use crate::{ffi, IntoPy, IntoPyPointer, PyClass, PyObject, Python};

/// Python Iterator Interface.
///
/// Check [CPython doc](https://docs.python.org/3/c-api/typeobj.html#c.PyTypeObject.tp_iter)
/// for more.
#[allow(unused_variables)]
pub trait PyIterProtocol<'p>: PyClass {
fn __iter__(slf: PyRefMut<Self>) -> Self::Result
fn __iter__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterIterProtocol<'p>,
{
unimplemented!()
}

fn __next__(slf: PyRefMut<Self>) -> Self::Result
fn __next__(slf: Self::Receiver) -> Self::Result
where
Self: PyIterNextProtocol<'p>,
{
Expand All @@ -28,11 +29,13 @@ pub trait PyIterProtocol<'p>: PyClass {
}

pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Self::Success>>;
}

pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Success: crate::IntoPy<PyObject>;
type Result: Into<PyResult<Option<Self::Success>>>;
}
Expand Down Expand Up @@ -76,7 +79,7 @@ where
{
#[inline]
fn tp_iter() -> Option<ffi::getiterfunc> {
py_unary_refmut_func!(PyIterIterProtocol, T::__iter__)
py_unarys_func!(PyIterIterProtocol, T::__iter__)
}
}

Expand All @@ -99,7 +102,7 @@ where
{
#[inline]
fn tp_iternext() -> Option<ffi::iternextfunc> {
py_unary_refmut_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
py_unarys_func!(PyIterNextProtocol, T::__next__, IterNextConverter)
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/class/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ macro_rules! py_unary_func {

#[macro_export]
#[doc(hidden)]
macro_rules! py_unary_refmut_func {
macro_rules! py_unarys_func {
($trait:ident, $class:ident :: $f:ident $(, $conv:expr)?) => {{
unsafe extern "C" fn wrap<T>(slf: *mut $crate::ffi::PyObject) -> *mut $crate::ffi::PyObject
where
Expand All @@ -38,7 +38,9 @@ macro_rules! py_unary_refmut_func {
let py = pool.python();
$crate::run_callback(py, || {
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
let res = $class::$f(slf.borrow_mut()).into();
let borrow = <T::Receiver>::try_from_pycell(slf)
.map_err(|e| e.into())?;
let res = $class::$f(borrow).into();
$crate::callback::convert(py, res $(.map($conv))?)
})
}
Expand Down
26 changes: 24 additions & 2 deletions src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

//! Functionality for the code generated by the derive backend
use crate::err::PyResult;
use crate::err::{PyErr, PyResult};
use crate::exceptions::TypeError;
use crate::instance::PyNativeType;
use crate::pyclass::PyClass;
use crate::pyclass_init::PyClassInitializer;
use crate::types::{PyAny, PyDict, PyModule, PyTuple};
use crate::{ffi, GILPool, IntoPy, PyObject, Python};
use crate::{ffi, GILPool, IntoPy, PyCell, PyObject, Python};
use std::cell::UnsafeCell;

/// Description of a python parameter; used for `parse_args()`.
Expand Down Expand Up @@ -243,3 +243,25 @@ where
{
type Target = T;
}

/// A trait for types that can be borrowed from a cell.
///
/// This serves to unify the use of `PyRef` and `PyRefMut` in automatically
/// derived code, since both types can be obtained from a `PyCell`.
#[doc(hidden)]
pub trait TryFromPyCell<'a, T: PyClass>: Sized {
type Error: Into<PyErr>;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error>;
}

impl<'a, T, R> TryFromPyCell<'a, T> for R
where
T: 'a + PyClass,
R: std::convert::TryFrom<&'a PyCell<T>>,
R::Error: Into<PyErr>,
{
type Error = R::Error;
fn try_from_pycell(cell: &'a crate::PyCell<T>) -> Result<Self, Self::Error> {
<R as std::convert::TryFrom<&'a PyCell<T>>>::try_from(cell)
}
}
4 changes: 2 additions & 2 deletions tests/test_dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ struct Iterator {

#[pyproto]
impl<'p> PyIterProtocol for Iterator {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<Py<Iterator>> {
fn __iter__(slf: PyRef<'p, Self>) -> PyResult<Py<Iterator>> {
Ok(slf.into())
}

fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<i32>> {
fn __next__(mut slf: PyRefMut<'p, Self>) -> PyResult<Option<i32>> {
Ok(slf.iter.next())
}
}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pyself.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct Iter {

#[pyproto]
impl PyIterProtocol for Iter {
fn __iter__(slf: PyRefMut<Self>) -> PyResult<PyObject> {
fn __iter__(slf: PyRef<Self>) -> PyResult<PyObject> {
let py = unsafe { Python::assume_gil_acquired() };
Ok(slf.into_py(py))
}
Expand Down

0 comments on commit a58a1cf

Please sign in to comment.