Skip to content

Commit

Permalink
make 'cargo test' pass (modulo test_getter_setter)
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Oct 29, 2024
1 parent cd443eb commit a1437c3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 22 deletions.
8 changes: 5 additions & 3 deletions guide/src/class/protocols.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,20 @@ Example:
```rust
use pyo3::prelude::*;

use std::sync::Mutex;

#[pyclass]
struct MyIterator {
iter: Box<dyn Iterator<Item = PyObject> + Send>,
iter: Mutex<Box<dyn Iterator<Item = PyObject> + Send>>,
}

#[pymethods]
impl MyIterator {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<PyObject> {
slf.iter.next()
fn __next__(slf: PyRefMut<'_, Self>) -> Option<PyObject> {
slf.iter.lock().unwrap().next()
}
}
```
Expand Down
2 changes: 1 addition & 1 deletion guide/src/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -1797,7 +1797,7 @@ There can be two fixes:
```

After:
```rust
```rust,ignore
# #![allow(dead_code)]
use pyo3::prelude::*;
use std::sync::{Arc, Mutex};
Expand Down
32 changes: 18 additions & 14 deletions tests/test_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use pyo3::prelude::*;
use pyo3::py_run;
use std::cell::Cell;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::Once;
use std::sync::{Arc, Mutex};

#[path = "../src/tests/common.rs"]
mod common;
Expand Down Expand Up @@ -403,28 +403,31 @@ fn tries_gil_in_traverse() {
fn traverse_cannot_be_hijacked() {
#[pyclass]
struct HijackedTraverse {
traversed: Cell<bool>,
hijacked: Cell<bool>,
traversed: AtomicBool,
hijacked: AtomicBool,
}

impl HijackedTraverse {
fn new() -> Self {
Self {
traversed: Cell::new(false),
hijacked: Cell::new(false),
traversed: AtomicBool::new(false),
hijacked: AtomicBool::new(false),
}
}

fn traversed_and_hijacked(&self) -> (bool, bool) {
(self.traversed.get(), self.hijacked.get())
(
self.traversed.load(Ordering::Acquire),
self.hijacked.load(Ordering::Acquire),
)
}
}

#[pymethods]
impl HijackedTraverse {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.traversed.set(true);
self.traversed.store(true, Ordering::Release);
Ok(())
}
}
Expand All @@ -436,7 +439,7 @@ fn traverse_cannot_be_hijacked() {

impl Traversable for PyRef<'_, HijackedTraverse> {
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.hijacked.set(true);
self.hijacked.store(true, Ordering::Release);
Ok(())
}
}
Expand All @@ -455,15 +458,16 @@ fn traverse_cannot_be_hijacked() {

#[pyclass]
struct DropDuringTraversal {
cycle: Cell<Option<Py<Self>>>,
cycle: Mutex<Option<Py<Self>>>,
_guard: DropGuard,
}

#[pymethods]
impl DropDuringTraversal {
#[allow(clippy::unnecessary_wraps)]
fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.cycle.take();
let mut cycle_ref = self.cycle.lock().unwrap();
*cycle_ref = None;
Ok(())
}
}
Expand All @@ -474,7 +478,7 @@ fn drop_during_traversal_with_gil() {
let (guard, check) = drop_check();

let ptr = Python::with_gil(|py| {
let cycle = Cell::new(None);
let cycle = Mutex::new(None);
let inst = Py::new(
py,
DropDuringTraversal {
Expand All @@ -484,7 +488,7 @@ fn drop_during_traversal_with_gil() {
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));
*inst.borrow_mut(py).cycle.lock().unwrap() = Some(inst.clone_ref(py));

check.assert_not_dropped();
let ptr = inst.as_ptr();
Expand All @@ -508,7 +512,7 @@ fn drop_during_traversal_without_gil() {
let (guard, check) = drop_check();

let inst = Python::with_gil(|py| {
let cycle = Cell::new(None);
let cycle = Mutex::new(None);
let inst = Py::new(
py,
DropDuringTraversal {
Expand All @@ -518,7 +522,7 @@ fn drop_during_traversal_without_gil() {
)
.unwrap();

inst.borrow_mut(py).cycle.set(Some(inst.clone_ref(py)));
*inst.borrow_mut(py).cycle.lock().unwrap() = Some(inst.clone_ref(py));

check.assert_not_dropped();
inst
Expand Down
9 changes: 5 additions & 4 deletions tests/test_proto_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::exceptions::{PyAttributeError, PyIndexError, PyValueError};
use pyo3::types::{PyDict, PyList, PyMapping, PySequence, PySlice, PyType};
use pyo3::{prelude::*, py_run};
use std::iter;
use std::sync::Mutex;

#[path = "../src/tests/common.rs"]
mod common;
Expand Down Expand Up @@ -361,7 +362,7 @@ fn sequence() {

#[pyclass]
struct Iterator {
iter: Box<dyn iter::Iterator<Item = i32> + Send>,
iter: Mutex<Box<dyn iter::Iterator<Item = i32> + Send>>,
}

#[pymethods]
Expand All @@ -370,8 +371,8 @@ impl Iterator {
slf
}

fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<i32> {
slf.iter.next()
fn __next__(slf: PyRefMut<'_, Self>) -> Option<i32> {
slf.iter.lock().unwrap().next()
}
}

Expand All @@ -381,7 +382,7 @@ fn iterator() {
let inst = Py::new(
py,
Iterator {
iter: Box::new(5..8),
iter: Mutex::new(Box::new(5..8)),
},
)
.unwrap();
Expand Down

0 comments on commit a1437c3

Please sign in to comment.