diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 6afb66ef309..a3d2f4ad0ed 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -475,10 +475,14 @@ fn impl_arg_param( let ty = arg.ty; let name = arg.name; + let transform_error = quote! { + |e| pyo3::derive_utils::argument_extraction_error(_py, stringify!(#name), e) + }; if spec.is_args(&name) { return quote! { - let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())?; + let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref()) + .map_err(#transform_error)?; }; } else if spec.is_kwargs(&name) { return quote! { @@ -518,7 +522,7 @@ fn impl_arg_param( quote! { let #mut_ _tmp: #target_ty = match #arg_value { - Some(_obj) => _obj.extract()?, + Some(_obj) => _obj.extract().map_err(#transform_error)?, None => #default, }; let #arg_name = #borrow_tmp; @@ -526,7 +530,7 @@ fn impl_arg_param( } else { quote! { let #arg_name = match #arg_value { - Some(_obj) => _obj.extract()?, + Some(_obj) => _obj.extract().map_err(#transform_error)?, None => #default, }; } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 52e22fac2f5..4c2eee5ec86 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -8,7 +8,7 @@ use crate::err::{PyErr, PyResult}; use crate::exceptions::PyTypeError; use crate::instance::PyNativeType; use crate::pyclass::{PyClass, PyClassThreadChecker}; -use crate::types::{PyAny, PyDict, PyModule, PyTuple}; +use crate::types::{PyAny, PyDict, PyModule, PyString, PyTuple}; use crate::{ffi, GILPool, IntoPy, PyCell, Python}; use std::cell::UnsafeCell; @@ -111,6 +111,18 @@ pub fn parse_fn_args<'p>( Ok((args, kwargs)) } +/// Add the argument name to the error message of an error which occurred during argument extraction +pub fn argument_extraction_error(py: Python, arg_name: &str, original_error: PyErr) -> PyErr { + let reason = original_error + .instance(py) + .str() + .unwrap_or_else(|_| PyString::new(py, "")); + PyErr::from_type( + original_error.ptype(py), + format!("argument '{}': {}", arg_name, reason), + ) +} + /// `Sync` wrapper of `ffi::PyModuleDef`. #[doc(hidden)] pub struct ModuleDef(UnsafeCell); diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 6f6bf44c077..affb768af65 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -119,3 +119,54 @@ fn test_raw_function() { .unwrap(); assert_eq!(res, "Some(true)"); } + +#[pyfunction] +fn conversion_error(str_arg: &str, int_arg: i64, tuple_arg: (&str, f64), option_arg: Option) { + println!( + "{:?} {:?} {:?} {:?}", + str_arg, int_arg, tuple_arg, option_arg + ); +} + +#[test] +fn test_conversion_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let conversion_error = wrap_pyfunction!(conversion_error)(py).unwrap(); + py_expect_exception!( + py, + conversion_error, + "conversion_error(None, None, None, None)", + PyTypeError, + "argument 'str_arg': Can't convert None to PyString" + ); + py_expect_exception!( + py, + conversion_error, + "conversion_error(100, None, None, None)", + PyTypeError, + "argument 'str_arg': Can't convert 100 to PyString" + ); + py_expect_exception!( + py, + conversion_error, + "conversion_error('string1', 'string2', None, None)", + PyTypeError, + "argument 'int_arg': 'str' object cannot be interpreted as an integer" + ); + py_expect_exception!( + py, + conversion_error, + "conversion_error('string1', -100, 'string2', None)", + PyTypeError, + "argument 'tuple_arg': Can't convert 'string2' to PyTuple" + ); + py_expect_exception!( + py, + conversion_error, + "conversion_error('string1', -100, ('string2', 10.), 'string3')", + PyTypeError, + "argument 'option_arg': 'str' object cannot be interpreted as an integer" + ); +} diff --git a/tests/test_string.rs b/tests/test_string.rs index 38d375b5418..4e1ea4e1b07 100644 --- a/tests/test_string.rs +++ b/tests/test_string.rs @@ -1,5 +1,4 @@ use pyo3::prelude::*; -use pyo3::py_run; use pyo3::wrap_pyfunction; mod common; @@ -15,15 +14,11 @@ fn test_unicode_encode_error() { let py = gil.python(); let take_str = wrap_pyfunction!(take_str)(py).unwrap(); - py_run!( + py_expect_exception!( py, take_str, - r#" - try: - take_str('\ud800') - except UnicodeEncodeError as e: - error_msg = "'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed" - assert str(e) == error_msg - "# + "take_str('\\ud800')", + PyUnicodeEncodeError, + "argument '_s': 'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed" ); }