Skip to content

Commit

Permalink
Add argument name to error messages in conversion errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Askaholic committed Oct 13, 2020
1 parent 007bfb7 commit fea9dfb
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 13 deletions.
10 changes: 7 additions & 3 deletions pyo3-derive-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,14 @@ fn impl_arg_param(

let ty = arg.ty;
let name = arg.name;
let transform_error = quote! {
|e| pyo3::derive_utils::argument_extraction_error(_py, stringify!(#name), e)
};

if spec.is_args(&name) {
return quote! {
let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())?;
let #arg_name = <#ty as pyo3::FromPyObject>::extract(_args.as_ref())
.map_err(#transform_error)?;
};
} else if spec.is_kwargs(&name) {
return quote! {
Expand Down Expand Up @@ -518,15 +522,15 @@ fn impl_arg_param(

quote! {
let #mut_ _tmp: #target_ty = match #arg_value {
Some(_obj) => _obj.extract()?,
Some(_obj) => _obj.extract().map_err(#transform_error)?,
None => #default,
};
let #arg_name = #borrow_tmp;
}
} else {
quote! {
let #arg_name = match #arg_value {
Some(_obj) => _obj.extract()?,
Some(_obj) => _obj.extract().map_err(#transform_error)?,
None => #default,
};
}
Expand Down
14 changes: 13 additions & 1 deletion src/derive_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::err::{PyErr, PyResult};
use crate::exceptions::PyTypeError;
use crate::instance::PyNativeType;
use crate::pyclass::{PyClass, PyClassThreadChecker};
use crate::types::{PyAny, PyDict, PyModule, PyTuple};
use crate::types::{PyAny, PyDict, PyModule, PyString, PyTuple};
use crate::{ffi, GILPool, IntoPy, PyCell, Python};
use std::cell::UnsafeCell;

Expand Down Expand Up @@ -111,6 +111,18 @@ pub fn parse_fn_args<'p>(
Ok((args, kwargs))
}

/// Add the argument name to the error message of an error which occurred during argument extraction
pub fn argument_extraction_error(py: Python, arg_name: &str, original_error: PyErr) -> PyErr {
let reason = original_error
.instance(py)
.str()
.unwrap_or_else(|_| PyString::new(py, ""));
PyErr::from_type(
original_error.ptype(py),
format!("argument '{}': {}", arg_name, reason),
)
}

/// `Sync` wrapper of `ffi::PyModuleDef`.
#[doc(hidden)]
pub struct ModuleDef(UnsafeCell<ffi::PyModuleDef>);
Expand Down
51 changes: 51 additions & 0 deletions tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,54 @@ fn test_raw_function() {
.unwrap();
assert_eq!(res, "Some(true)");
}

#[pyfunction]
fn conversion_error(str_arg: &str, int_arg: i64, tuple_arg: (&str, f64), option_arg: Option<i64>) {
println!(
"{:?} {:?} {:?} {:?}",
str_arg, int_arg, tuple_arg, option_arg
);
}

#[test]
fn test_conversion_error() {
let gil = Python::acquire_gil();
let py = gil.python();

let conversion_error = wrap_pyfunction!(conversion_error)(py).unwrap();
py_expect_exception!(
py,
conversion_error,
"conversion_error(None, None, None, None)",
PyTypeError,
"argument 'str_arg': Can't convert None to PyString"
);
py_expect_exception!(
py,
conversion_error,
"conversion_error(100, None, None, None)",
PyTypeError,
"argument 'str_arg': Can't convert 100 to PyString"
);
py_expect_exception!(
py,
conversion_error,
"conversion_error('string1', 'string2', None, None)",
PyTypeError,
"argument 'int_arg': 'str' object cannot be interpreted as an integer"
);
py_expect_exception!(
py,
conversion_error,
"conversion_error('string1', -100, 'string2', None)",
PyTypeError,
"argument 'tuple_arg': Can't convert 'string2' to PyTuple"
);
py_expect_exception!(
py,
conversion_error,
"conversion_error('string1', -100, ('string2', 10.), 'string3')",
PyTypeError,
"argument 'option_arg': 'str' object cannot be interpreted as an integer"
);
}
13 changes: 4 additions & 9 deletions tests/test_string.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::wrap_pyfunction;

mod common;
Expand All @@ -15,15 +14,11 @@ fn test_unicode_encode_error() {
let py = gil.python();

let take_str = wrap_pyfunction!(take_str)(py).unwrap();
py_run!(
py_expect_exception!(
py,
take_str,
r#"
try:
take_str('\ud800')
except UnicodeEncodeError as e:
error_msg = "'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed"
assert str(e) == error_msg
"#
"take_str('\\ud800')",
PyUnicodeEncodeError,
"argument '_s': 'utf-8' codec can't encode character '\\ud800' in position 0: surrogates not allowed"
);
}

0 comments on commit fea9dfb

Please sign in to comment.