diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 6afb66ef309..a3d2f4ad0ed 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -475,10 +475,14 @@ fn impl_arg_param( let ty = arg.ty; let name = arg.name; + let transform_error = quote! { + |e| pyo3::derive_utils::argument_extraction_error(_py, stringify!(#name), e) + }; if spec.is_args(&name) { return quote! { - let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())?; + let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref()) + .map_err(#transform_error)?; }; } else if spec.is_kwargs(&name) { return quote! { @@ -518,7 +522,7 @@ fn impl_arg_param( quote! { let #mut_ _tmp: #target_ty = match #arg_value { - Some(_obj) => _obj.extract()?, + Some(_obj) => _obj.extract().map_err(#transform_error)?, None => #default, }; let #arg_name = #borrow_tmp; @@ -526,7 +530,7 @@ fn impl_arg_param( } else { quote! { let #arg_name = match #arg_value { - Some(_obj) => _obj.extract()?, + Some(_obj) => _obj.extract().map_err(#transform_error)?, None => #default, }; } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 52e22fac2f5..e9fff47173e 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -111,6 +111,18 @@ pub fn parse_fn_args<'p>( Ok((args, kwargs)) } +/// Add the argument name to the error message of an error which occurred during argument extraction +pub fn argument_extraction_error( + py: Python, + arg_name: &str, + original_error: PyErr, +) -> PyErr { + PyErr::from_type( + original_error.ptype(py), + format!("argument '{}': {}", arg_name, original_error.instance(py)), + ) +} + /// `Sync` wrapper of `ffi::PyModuleDef`. #[doc(hidden)] pub struct ModuleDef(UnsafeCell); diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index 6f6bf44c077..ae7ccb10a09 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -5,6 +5,61 @@ use pyo3::{raw_pycfunction, wrap_pyfunction}; mod common; +#[pyfunction] +fn conversion_error(str_arg: &str, int_arg: i64) { + println!("{:?} {:?}", str_arg, int_arg); +} + +#[test] +fn test_conversion_error() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let conversion_error = wrap_pyfunction!(conversion_error)(py).unwrap(); + py_expect_exception!( + py, + conversion_error, + "conversion_error('100, -100)", + PyTypeError, + "argument 'str_arg': Can't convert 100 to PyString" + ); + py_expect_exception!( + py, + conversion_error, + "conversion_error('a string', 'another string')", + PyTypeError, + "argument 'int_arg': 'str' object cannot be interpreted as an integer" + ); +} + +#[pyfunction] +#[text_signature = "(arg1, arg2)"] +fn conversion_error_signature(tuple_arg: (&str, f64), option_arg: Option) { + println!("{:?} {:?}", tuple_arg, option_arg); +} + +#[test] +fn test_conversion_error_signature() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let conversion_error_signature = wrap_pyfunction!(conversion_error_signature)(py).unwrap(); + py_expect_exception!( + py, + conversion_error_signature, + "conversion_error_signature('a string', 'another string')", + PyTypeError, + "argument 'arg1': Can't convert 'a string' to PyTuple" + ); + py_expect_exception!( + py, + conversion_error_signature, + "conversion_error_signature('100, -100)", + PyTypeError, + "argument 'arg2': Can't convert '-100' to Option" + ); +} + #[pyfunction(arg = "true")] fn optional_bool(arg: Option) -> String { format!("{:?}", arg) diff --git a/tests/test_string.rs b/tests/test_string.rs index 38d375b5418..4e1ea4e1b07 100644 --- a/tests/test_string.rs +++ b/tests/test_string.rs @@ -1,5 +1,4 @@ use pyo3::prelude::*; -use pyo3::py_run; use pyo3::wrap_pyfunction; mod common; @@ -15,15 +14,11 @@ fn test_unicode_encode_error() { let py = gil.python(); let take_str = wrap_pyfunction!(take_str)(py).unwrap(); - py_run!( + py_expect_exception!( py, take_str, - r#" - try: - take_str('\ud800') - except UnicodeEncodeError as e: - error_msg = "'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed" - assert str(e) == error_msg - "# + "take_str('\\ud800')", + PyUnicodeEncodeError, + "argument '_s': 'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed" ); }