diff --git a/src/type_object.rs b/src/type_object.rs index f37bb5fc4bb..15b525fd60c 100644 --- a/src/type_object.rs +++ b/src/type_object.rs @@ -256,7 +256,7 @@ where let gil = Python::acquire_gil(); let py = gil.python(); - initialize_type::(py).unwrap_or_else(|_| { + initialize_type::(py, None).unwrap_or_else(|_| { panic!("An error occurred while initializing class {}", Self::NAME) }); } @@ -290,21 +290,15 @@ pub trait PyTypeCreate: PyObjectAlloc + PyTypeObject + Sized { impl PyTypeCreate for T where T: PyObjectAlloc + PyTypeObject + Sized {} /// Register new type in python object system. -/// -/// Currently, module_name is always None, so it defaults to pyo3_extension #[cfg(not(Py_LIMITED_API))] -pub fn initialize_type(py: Python) -> PyResult<*mut ffi::PyTypeObject> +pub fn initialize_type(py: Python, module_name: Option<&str>) -> PyResult<*mut ffi::PyTypeObject> where T: PyObjectAlloc + PyTypeInfo + PyMethodsProtocol, { - let type_name = CString::new(T::NAME).expect("class name must not contain NUL byte"); - let type_object: &mut ffi::PyTypeObject = unsafe { T::type_object() }; let base_type_object: &mut ffi::PyTypeObject = unsafe { ::type_object() }; - type_object.tp_name = type_name.into_raw(); - // PyPy will segfault if passed only a nul terminator as `tp_doc`. // ptr::null() is OK though. if T::DESCRIPTION == "\0" { @@ -315,6 +309,13 @@ where type_object.tp_base = base_type_object; + let name = match module_name { + Some(module_name) => format!("{}.{}", module_name, T::NAME), + None => T::NAME.to_string(), + }; + let name = CString::new(name).expect("Module name/type name must not contain NUL byte"); + type_object.tp_name = name.into_raw(); + // dealloc type_object.tp_dealloc = Some(tp_dealloc_callback::); diff --git a/tests/test_class_basics.rs b/tests/test_class_basics.rs index 3b91e429e7e..7a4959c813b 100644 --- a/tests/test_class_basics.rs +++ b/tests/test_class_basics.rs @@ -1,4 +1,5 @@ use pyo3::prelude::*; +use pyo3::type_object::initialize_type; #[macro_use] mod common; @@ -71,4 +72,9 @@ fn empty_class_in_module() { // We currently have no way of determining a canonical module, so builtins is better // than using whatever calls init first. assert_eq!(module, "builtins"); + + // The module name can also be set manually by calling `initialize_type`. + initialize_type::(py, Some("test_module.nested")).unwrap(); + let module: String = ty.getattr("__module__").unwrap().extract().unwrap(); + assert_eq!(module, "test_module.nested"); } diff --git a/tests/test_various.rs b/tests/test_various.rs index 5af5792a055..08bb35e676f 100644 --- a/tests/test_various.rs +++ b/tests/test_various.rs @@ -1,6 +1,7 @@ use pyo3::prelude::*; +use pyo3::type_object::initialize_type; use pyo3::types::IntoPyDict; -use pyo3::types::PyTuple; +use pyo3::types::{PyDict, PyTuple}; use pyo3::wrap_pyfunction; use std::isize; @@ -117,3 +118,55 @@ fn pytuple_pyclass_iter() { py_assert!(py, tup, "type(tup[0]).__name__ == type(tup[0]).__name__"); py_assert!(py, tup, "tup[0] != tup[1]"); } + +#[pyclass(dict)] +struct PickleSupport {} + +#[pymethods] +impl PickleSupport { + #[new] + fn new(obj: &PyRawObject) { + obj.init({ PickleSupport {} }); + } + + pub fn __reduce__(slf: PyRef) -> PyResult<(PyObject, Py, PyObject)> { + let gil = Python::acquire_gil(); + let py = gil.python(); + let cls = slf.to_object(py).getattr(py, "__class__")?; + let dict = slf.to_object(py).getattr(py, "__dict__")?; + Ok((cls, PyTuple::empty(py), dict)) + } +} + +fn add_module(py: Python, module: &PyModule) -> PyResult<()> { + py.import("sys")? + .dict() + .get_item("modules") + .unwrap() + .downcast_mut::()? + .set_item(module.name()?, module) +} + +#[test] +fn test_pickle() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let module = PyModule::new(py, "test_module").unwrap(); + module.add_class::().unwrap(); + add_module(py, module).unwrap(); + initialize_type::(py, Some("test_module")).unwrap(); + let inst = PyRef::new(py, PickleSupport {}).unwrap(); + py_run!( + py, + inst, + r#" + inst.a = 1 + assert inst.__dict__ == {'a': 1} + + import pickle + inst2 = pickle.loads(pickle.dumps(inst)) + + assert inst2.__dict__ == {'a': 1} + "# + ); +}