Skip to content

Commit

Permalink
simplify thread checker implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 3, 2023
1 parent 218a595 commit 4c46d81
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 74 deletions.
2 changes: 1 addition & 1 deletion guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,7 @@ impl pyo3::impl_::pyclass::PyClassImpl for MyClass {
const IS_SUBCLASS: bool = false;
type Layout = PyCell<MyClass>;
type BaseType = PyAny;
type ThreadChecker = pyo3::impl_::pyclass::ThreadCheckerStub<MyClass>;
type ThreadChecker = pyo3::impl_::pyclass::SendablePyClass<MyClass>;
type PyClassMutability = <<pyo3::PyAny as pyo3::impl_::pyclass::PyClassBaseType>::PyClassMutability as pyo3::impl_::pycell::PyClassMutability>::MutableChild;
type Dict = pyo3::impl_::pyclass::PyClassDummySlot;
type WeakRef = pyo3::impl_::pyclass::PyClassDummySlot;
Expand Down
8 changes: 2 additions & 6 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,13 +947,9 @@ impl<'a> PyClassImplsBuilder<'a> {
};

let thread_checker = if self.attr.options.unsendable.is_some() {
quote! { _pyo3::impl_::pyclass::ThreadCheckerImpl<#cls> }
} else if self.attr.options.extends.is_some() {
quote! {
_pyo3::impl_::pyclass::ThreadCheckerInherited<#cls, <#cls as _pyo3::impl_::pyclass::PyClassImpl>::BaseType>
}
quote! { _pyo3::impl_::pyclass::ThreadCheckerImpl }
} else {
quote! { _pyo3::impl_::pyclass::ThreadCheckerStub<#cls> }
quote! { _pyo3::impl_::pyclass::SendablePyClass<#cls> }
};

let (pymethods_items, inventory, inventory_class) = match self.methods_type {
Expand Down
67 changes: 22 additions & 45 deletions src/impl_/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1018,85 +1018,64 @@ pub trait PyClassThreadChecker<T>: Sized {
private_decl! {}
}

/// Stub checker for `Send` types.
/// Default thread checker for `#[pyclass]`.
///
/// Keeping the T: Send bound here slightly improves the compile
/// error message to hint to users to figure out what's wrong
/// when `#[pyclass]` types do not implement `Send`.
#[doc(hidden)]
pub struct ThreadCheckerStub<T: Send>(PhantomData<T>);
pub struct SendablePyClass<T: Send>(PhantomData<T>);

impl<T: Send> PyClassThreadChecker<T> for ThreadCheckerStub<T> {
impl<T: Send> PyClassThreadChecker<T> for SendablePyClass<T> {
fn ensure(&self) {}
fn can_drop(&self, _py: Python<'_>) -> bool {
true
}
#[inline]
fn new() -> Self {
ThreadCheckerStub(PhantomData)
SendablePyClass(PhantomData)
}
private_impl! {}
}

impl<T: PyNativeType> PyClassThreadChecker<T> for ThreadCheckerStub<crate::PyObject> {
fn ensure(&self) {}
fn can_drop(&self, _py: Python<'_>) -> bool {
true
}
#[inline]
fn new() -> Self {
ThreadCheckerStub(PhantomData)
}
private_impl! {}
}

/// Thread checker for unsendable types.
/// Thread checker for `#[pyclass(unsendable)]` types.
/// Panics when the value is accessed by another thread.
#[doc(hidden)]
pub struct ThreadCheckerImpl<T>(thread::ThreadId, PhantomData<T>);
pub struct ThreadCheckerImpl(thread::ThreadId);

impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl<T> {
fn ensure(&self) {
impl ThreadCheckerImpl {
fn ensure(&self, type_name: &'static str) {
assert_eq!(
thread::current().id(),
self.0,
"{} is unsendable, but sent to another thread!",
std::any::type_name::<T>()
"{} is unsendable, but sent to another thread",
type_name
);
}
fn can_drop(&self, py: Python<'_>) -> bool {

fn can_drop(&self, py: Python<'_>, type_name: &'static str) -> bool {
if thread::current().id() != self.0 {
PyRuntimeError::new_err(format!(
"{} is unsendbale, but is dropped on another thread!",
std::any::type_name::<T>()
"{} is unsendable, but is being dropped on another thread",
type_name
))
.write_unraisable(py, None);
return false;
}

true
}
fn new() -> Self {
ThreadCheckerImpl(thread::current().id(), PhantomData)
}
private_impl! {}
}

/// Thread checker for types that have `Send` and `extends=...`.
/// Ensures that `T: Send` and the parent is not accessed by another thread.
#[doc(hidden)]
pub struct ThreadCheckerInherited<T: PyClass + Send, U: PyClassBaseType>(
PhantomData<T>,
U::ThreadChecker,
);

impl<T: PyClass + Send, U: PyClassBaseType> PyClassThreadChecker<T>
for ThreadCheckerInherited<T, U>
{
impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
fn ensure(&self) {
self.1.ensure();
self.ensure(std::any::type_name::<T>());
}
fn can_drop(&self, py: Python<'_>) -> bool {
self.1.can_drop(py)
self.can_drop(py, std::any::type_name::<T>())
}
fn new() -> Self {
ThreadCheckerInherited(PhantomData, U::ThreadChecker::new())
ThreadCheckerImpl(thread::current().id())
}
private_impl! {}
}
Expand All @@ -1105,7 +1084,6 @@ impl<T: PyClass + Send, U: PyClassBaseType> PyClassThreadChecker<T>
pub trait PyClassBaseType: Sized {
type LayoutAsBase: PyCellLayout<Self>;
type BaseNativeType;
type ThreadChecker: PyClassThreadChecker<Self>;
type Initializer: PyObjectInit<Self>;
type PyClassMutability: PyClassMutability;
}
Expand All @@ -1116,7 +1094,6 @@ pub trait PyClassBaseType: Sized {
impl<T: PyClass> PyClassBaseType for T {
type LayoutAsBase = crate::pycell::PyCell<T>;
type BaseNativeType = T::BaseNativeType;
type ThreadChecker = T::ThreadChecker;
type Initializer = crate::pyclass_init::PyClassInitializer<Self>;
type PyClassMutability = T::PyClassMutability;
}
Expand Down
1 change: 0 additions & 1 deletion src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ macro_rules! pyobject_native_type_sized {
impl<$($generics,)*> $crate::impl_::pyclass::PyClassBaseType for $name {
type LayoutAsBase = $crate::pycell::PyCellBase<$layout>;
type BaseNativeType = $name;
type ThreadChecker = $crate::impl_::pyclass::ThreadCheckerStub<$crate::PyObject>;
type Initializer = $crate::pyclass_init::PyNativeTypeInitializer<Self>;
type PyClassMutability = $crate::pycell::impl_::ImmutableClass;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/test_class_basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ fn test_unsendable<T: PyClass + 'static>() -> PyResult<()> {
#[test]
#[cfg_attr(target_arch = "wasm32", ignore)]
#[should_panic(
expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread!"
expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread"
)]
fn panic_unsendable_base() {
test_unsendable::<UnsendableBase>().unwrap();
Expand All @@ -277,7 +277,7 @@ fn panic_unsendable_base() {
#[test]
#[cfg_attr(target_arch = "wasm32", ignore)]
#[should_panic(
expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread!"
expected = "test_class_basics::UnsendableBase is unsendable, but sent to another thread"
)]
fn panic_unsendable_child() {
test_unsendable::<UnsendableChild>().unwrap();
Expand Down Expand Up @@ -584,7 +584,7 @@ fn drop_unsendable_elsewhere() {
assert!(!dropped.load(Ordering::SeqCst));

let (err, object) = capture.borrow_mut(py).capture.take().unwrap();
assert_eq!(err.to_string(), "RuntimeError: test_class_basics::drop_unsendable_elsewhere::Unsendable is unsendbale, but is dropped on another thread!");
assert_eq!(err.to_string(), "RuntimeError: test_class_basics::drop_unsendable_elsewhere::Unsendable is unsendable, but is being dropped on another thread");
assert!(object.is_none(py));

capture.borrow_mut(py).uninstall(py);
Expand Down
15 changes: 0 additions & 15 deletions tests/ui/abi3_nativetype_inheritance.stderr
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
error[E0277]: the trait bound `PyDict: PyClass` is not satisfied
--> tests/ui/abi3_nativetype_inheritance.rs:5:1
|
5 | #[pyclass(extends=PyDict)]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `PyClass` is not implemented for `PyDict`
|
= help: the trait `PyClass` is implemented for `TestClass`
= note: required for `PyDict` to implement `PyClassBaseType`
note: required by a bound in `ThreadCheckerInherited`
--> src/impl_/pyclass.rs
|
| pub struct ThreadCheckerInherited<T: PyClass + Send, U: PyClassBaseType>(
| ^^^^^^^^^^^^^^^ required by this bound in `ThreadCheckerInherited`
= note: this error originates in the attribute macro `pyclass` (in Nightly builds, run with -Z macro-backtrace for more info)

error[E0277]: the trait bound `PyDict: PyClass` is not satisfied
--> tests/ui/abi3_nativetype_inheritance.rs:5:1
|
Expand Down
6 changes: 3 additions & 3 deletions tests/ui/pyclass_send.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ note: required because it appears within the type `NotThreadSafe`
|
5 | struct NotThreadSafe {
| ^^^^^^^^^^^^^
note: required by a bound in `ThreadCheckerStub`
note: required by a bound in `SendablePyClass`
--> src/impl_/pyclass.rs
|
| pub struct ThreadCheckerStub<T: Send>(PhantomData<T>);
| ^^^^ required by this bound in `ThreadCheckerStub`
| pub struct SendablePyClass<T: Send>(PhantomData<T>);
| ^^^^ required by this bound in `SendablePyClass`
= note: this error originates in the attribute macro `pyclass` (in Nightly builds, run with -Z macro-backtrace for more info)

0 comments on commit 4c46d81

Please sign in to comment.