From 871657e61f05305c37b9703a7d484c04b1b647a7 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Tue, 28 Mar 2023 18:29:49 +0200 Subject: [PATCH] Use a capsule-based API with a stable ABI for global, cross-extension functionality. --- src/global_api.rs | 83 +++++++++++++++++++++++++++++++++++++++++ src/impl_/trampoline.rs | 4 +- src/lib.rs | 1 + src/panic.rs | 43 +++++++++++++++++---- 4 files changed, 122 insertions(+), 9 deletions(-) create mode 100644 src/global_api.rs diff --git a/src/global_api.rs b/src/global_api.rs new file mode 100644 index 00000000000..9076c626e3f --- /dev/null +++ b/src/global_api.rs @@ -0,0 +1,83 @@ +//! TODO + +use std::{ffi::CString, mem::forget}; + +use crate::{ + conversion::PyTryInto, + exceptions::PyTypeError, + ffi, + sync::GILOnceCell, + type_object::PyTypeInfo, + types::{PyCapsule, PyModule}, + Py, PyResult, Python, +}; + +#[repr(C)] +pub(crate) struct GlobalApi { + version: u64, + pub(crate) create_panic_exception: + unsafe extern "C" fn(msg_ptr: *const u8, msg_len: usize) -> *mut ffi::PyObject, +} + +pub(crate) fn ensure_global_api(py: Python<'_>) -> PyResult<&GlobalApi> { + let api = GLOBAL_API.0.get_or_try_init(py, || init_global_api(py))?; + + // SAFETY: We inserted the capsule if it was missing + // and verified that it contains a compatible version. + Ok(unsafe { &**api }) +} + +struct GlobalApiPtr(GILOnceCell<*const GlobalApi>); + +unsafe impl Send for GlobalApiPtr {} + +unsafe impl Sync for GlobalApiPtr {} + +static GLOBAL_API: GlobalApiPtr = GlobalApiPtr(GILOnceCell::new()); + +#[cold] +fn init_global_api(py: Python<'_>) -> PyResult<*const GlobalApi> { + let module = match PyModule::import(py, "pyo3") { + Ok(module) => module, + Err(_err) => { + let module = PyModule::new(py, "pyo3")?; + + // TODO: Register module so it become globally visible and hence importable... + + module.add( + "PanicException", + crate::panic::PanicException::type_object(py), + )?; + + module + } + }; + + let capsule: &PyCapsule = match module.getattr("_GLOBAL_API") { + Ok(capsule) => PyTryInto::try_into(capsule)?, + Err(_err) => { + let api = GlobalApi { + version: 1, + create_panic_exception: crate::panic::create_panic_exception, + }; + + let capsule = PyCapsule::new(py, api, Some(CString::new("_GLOBAL_API").unwrap()))?; + module.setattr("_GLOBAL_API", capsule)?; + capsule + } + }; + + // SAFETY: All versions of the global API start with a version field. + let version = unsafe { *(capsule.pointer() as *mut u64) }; + if version < 1 { + return Err(PyTypeError::new_err(format!( + "Version {} of global API is not supported by this version of PyO3", + version + ))); + } + + // Intentionally leak a reference to the capsule so we can safely cache a pointer into its interior. + forget(Py::::from(capsule)); + + Ok(capsule.pointer() as *const GlobalApi) +} diff --git a/src/impl_/trampoline.rs b/src/impl_/trampoline.rs index c7bea9abe2e..deea335fd82 100644 --- a/src/impl_/trampoline.rs +++ b/src/impl_/trampoline.rs @@ -220,7 +220,7 @@ where let py_err = match panic_result { Ok(Ok(value)) => return value, Ok(Err(py_err)) => py_err, - Err(payload) => PanicException::from_panic_payload(payload), + Err(payload) => PanicException::from_panic_payload(py, payload), }; py_err.restore(py); R::ERR_VALUE @@ -245,7 +245,7 @@ where let pool = GILPool::new(); let py = pool.python(); if let Err(py_err) = panic::catch_unwind(move || body(py)) - .unwrap_or_else(|payload| Err(PanicException::from_panic_payload(payload))) + .unwrap_or_else(|payload| Err(PanicException::from_panic_payload(py, payload))) { py_err.write_unraisable(py, py.from_borrowed_ptr_or_opt(ctx)); } diff --git a/src/lib.rs b/src/lib.rs index ae37d65b1f2..77ad820e8c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -410,6 +410,7 @@ pub mod marker; pub mod marshal; #[macro_use] pub mod sync; +mod global_api; pub mod panic; pub mod prelude; pub mod pycell; diff --git a/src/panic.rs b/src/panic.rs index 50489d50aeb..cf164d5d300 100644 --- a/src/panic.rs +++ b/src/panic.rs @@ -1,7 +1,12 @@ //! Helper to convert Rust panics to Python exceptions. +use crate::conversion::{FromPyPointer, IntoPyPointer}; use crate::exceptions::PyBaseException; -use crate::PyErr; +use crate::ffi; +use crate::global_api::ensure_global_api; +use crate::{PyAny, Python}; use std::any::Any; +use std::slice; +use std::str; pyo3_exception!( " @@ -20,13 +25,37 @@ impl PanicException { /// /// Attempts to format the error in the same way panic does. #[cold] - pub(crate) fn from_panic_payload(payload: Box) -> PyErr { - if let Some(string) = payload.downcast_ref::() { - Self::new_err((string.clone(),)) + pub(crate) fn from_panic_payload<'py>( + py: Python<'py>, + payload: Box, + ) -> &'py PyAny { + let msg = if let Some(string) = payload.downcast_ref::() { + string.clone() } else if let Some(s) = payload.downcast_ref::<&str>() { - Self::new_err((s.to_string(),)) + s.to_string() } else { - Self::new_err(("panic from Rust code",)) - } + "panic from Rust code".to_owned() + }; + + let api = match ensure_global_api(py) { + Ok(api) => api, + // The global API is unavailable, hence we fall back to our own `PanicException`. + Err(err) => return PanicException::new_err((msg,)).into_value(py).into_ref(py), + }; + + let err = (api.create_panic_exception)(msg.as_ptr(), msg.len()); + + PyAny::from_owned_ptr(py, err) } } + +pub(crate) unsafe extern "C" fn create_panic_exception( + msg_ptr: *const u8, + msg_len: usize, +) -> *mut ffi::PyObject { + let msg = str::from_utf8_unchecked(slice::from_raw_parts(msg_ptr, msg_len)); + + let err = PanicException::new_err((msg,)); + + err.into_value(Python::assume_gil_acquired()).into_ptr() +}