Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Use a capsule-based API with a stable ABI for global, cross-extension functionality. #3073

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/global_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//! TODO

use std::{ffi::CString, mem::forget};

use crate::{
conversion::PyTryInto,
exceptions::PyTypeError,
ffi,
sync::GILOnceCell,
type_object::PyTypeInfo,
types::{PyCapsule, PyDict, PyModule},
Py, PyResult, Python,
};

#[repr(C)]
pub(crate) struct GlobalApi {
version: u64,
pub(crate) create_panic_exception:
unsafe extern "C" fn(msg_ptr: *const u8, msg_len: usize) -> *mut ffi::PyObject,
}

pub(crate) fn ensure_global_api(py: Python<'_>) -> PyResult<&GlobalApi> {
let api = GLOBAL_API.0.get_or_try_init(py, || init_global_api(py))?;

// SAFETY: We inserted the capsule if it was missing
// and verified that it contains a compatible version.
Ok(unsafe { &**api })
}

struct GlobalApiPtr(GILOnceCell<*const GlobalApi>);

unsafe impl Send for GlobalApiPtr {}

unsafe impl Sync for GlobalApiPtr {}

static GLOBAL_API: GlobalApiPtr = GlobalApiPtr(GILOnceCell::new());

#[cold]
fn init_global_api(py: Python<'_>) -> PyResult<*const GlobalApi> {
let module = match PyModule::import(py, "pyo3") {
Ok(module) => module,
Err(_err) => {
let module = PyModule::new(py, "pyo3")?;

module.add(
"PanicException",
crate::panic::PanicException::type_object(py),
)?;

let sys = PyModule::import(py, "sys")?;
let modules: &PyDict = sys.getattr("modules")?.downcast()?;
modules.set_item("pyo3", module)?;

module
}
};

let capsule: &PyCapsule = match module.getattr("_GLOBAL_API") {
Ok(capsule) => PyTryInto::try_into(capsule)?,
Err(_err) => {
let api = GlobalApi {
version: 1,
create_panic_exception: crate::panic::create_panic_exception,
};

let capsule = PyCapsule::new(py, api, Some(CString::new("_GLOBAL_API").unwrap()))?;
module.setattr("_GLOBAL_API", capsule)?;
capsule
}
};

// SAFETY: All versions of the global API start with a version field.
let version = unsafe { *(capsule.pointer() as *mut u64) };
if version < 1 {
return Err(PyTypeError::new_err(format!(
"Version {} of global API is not supported by this version of PyO3",
version
)));
}

// Intentionally leak a reference to the capsule so we can safely cache a pointer into its interior.
forget(Py::<PyCapsule>::from(capsule));

Ok(capsule.pointer() as *const GlobalApi)
}
4 changes: 2 additions & 2 deletions src/impl_/trampoline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ where
let py_err = match panic_result {
Ok(Ok(value)) => return value,
Ok(Err(py_err)) => py_err,
Err(payload) => PanicException::from_panic_payload(payload),
Err(payload) => PanicException::from_panic_payload(py, payload),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The types do not work out at all for now, but I wanted to discuss the concrete approach first before trying plough through the trampolines.

};
py_err.restore(py);
R::ERR_VALUE
Expand All @@ -245,7 +245,7 @@ where
let pool = GILPool::new();
let py = pool.python();
if let Err(py_err) = panic::catch_unwind(move || body(py))
.unwrap_or_else(|payload| Err(PanicException::from_panic_payload(payload)))
.unwrap_or_else(|payload| Err(PanicException::from_panic_payload(py, payload)))
{
py_err.write_unraisable(py, py.from_borrowed_ptr_or_opt(ctx));
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ pub mod marker;
pub mod marshal;
#[macro_use]
pub mod sync;
mod global_api;
pub mod panic;
pub mod prelude;
pub mod pycell;
Expand Down
43 changes: 36 additions & 7 deletions src/panic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
//! Helper to convert Rust panics to Python exceptions.
use crate::conversion::{FromPyPointer, IntoPyPointer};
use crate::exceptions::PyBaseException;
use crate::PyErr;
use crate::ffi;
use crate::global_api::ensure_global_api;
use crate::{PyAny, Python};
use std::any::Any;
use std::slice;
use std::str;

pyo3_exception!(
"
Expand All @@ -20,13 +25,37 @@ impl PanicException {
///
/// Attempts to format the error in the same way panic does.
#[cold]
pub(crate) fn from_panic_payload(payload: Box<dyn Any + Send + 'static>) -> PyErr {
if let Some(string) = payload.downcast_ref::<String>() {
Self::new_err((string.clone(),))
pub(crate) fn from_panic_payload<'py>(
py: Python<'py>,
payload: Box<dyn Any + Send + 'static>,
) -> &'py PyAny {
let msg = if let Some(string) = payload.downcast_ref::<String>() {
string.clone()
} else if let Some(s) = payload.downcast_ref::<&str>() {
Self::new_err((s.to_string(),))
s.to_string()
} else {
Self::new_err(("panic from Rust code",))
}
"panic from Rust code".to_owned()
};

let api = match ensure_global_api(py) {
Ok(api) => api,
// The global API is unavailable, hence we fall back to our own `PanicException`.
Err(err) => return PanicException::new_err((msg,)).into_value(py).into_ref(py),
};

let err = (api.create_panic_exception)(msg.as_ptr(), msg.len());

PyAny::from_owned_ptr(py, err)
}
}

pub(crate) unsafe extern "C" fn create_panic_exception(
msg_ptr: *const u8,
msg_len: usize,
) -> *mut ffi::PyObject {
let msg = str::from_utf8_unchecked(slice::from_raw_parts(msg_ptr, msg_len));

let err = PanicException::new_err((msg,));

err.into_value(Python::assume_gil_acquired()).into_ptr()
}