Skip to content

Commit

Permalink
Merge pull request #1494 from PyO3/enhance-py-run
Browse files Browse the repository at this point in the history
Extend py_run! to take locals dict and refactor tests using it
  • Loading branch information
kngwyu authored Mar 17, 2021
2 parents 6137e3a + 9b88a45 commit acf7271
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 210 deletions.
56 changes: 41 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ macro_rules! wrap_pymodule {
/// # Example
/// ```
/// use pyo3::{prelude::*, py_run, types::PyList};
/// let gil = Python::acquire_gil();
/// let py = gil.python();
/// let list = PyList::new(py, &[1, 2, 3]);
/// py_run!(py, list, "assert list == [1, 2, 3]");
/// Python::with_gil(|py| {
/// let list = PyList::new(py, &[1, 2, 3]);
/// py_run!(py, list, "assert list == [1, 2, 3]");
/// });
/// ```
///
/// You can use this macro to test pyfunctions or pyclasses quickly.
Expand Down Expand Up @@ -320,15 +320,33 @@ macro_rules! wrap_pymodule {
/// (self.hour, self.minute, self.second)
/// }
/// }
/// let gil = Python::acquire_gil();
/// let py = gil.python();
/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap();
/// let time_as_tuple = (8, 43, 16);
/// py_run!(py, time time_as_tuple, r#"
/// assert time.hour == 8
/// assert time.repl_japanese() == "8時43分16秒"
/// assert time.as_tuple() == time_as_tuple
/// "#);
/// Python::with_gil(|py| {
/// let time = PyCell::new(py, Time {hour: 8, minute: 43, second: 16}).unwrap();
/// let time_as_tuple = (8, 43, 16);
/// py_run!(py, time time_as_tuple, r#"
/// assert time.hour == 8
/// assert time.repl_japanese() == "8時43分16秒"
/// assert time.as_tuple() == time_as_tuple
/// "#);
/// });
/// ```
///
/// If you need to prepare the `locals` dict by yourself, you can pass it as `*locals`.
///
/// ```
/// use pyo3::prelude::*;
/// use pyo3::types::IntoPyDict;
/// #[pyclass]
/// struct MyClass {}
/// #[pymethods]
/// impl MyClass {
/// #[new]
/// fn new() -> Self { MyClass {} }
/// }
/// Python::with_gil(|py| {
/// let locals = [("C", py.get_type::<MyClass>())].into_py_dict(py);
/// pyo3::py_run!(py, *locals, "c = C()");
/// });
/// ```
///
/// **Note**
Expand All @@ -345,6 +363,12 @@ macro_rules! py_run {
($py:expr, $($val:ident)+, $code:expr) => {{
$crate::py_run_impl!($py, $($val)+, &$crate::unindent::unindent($code))
}};
($py:expr, *$dict:expr, $code:literal) => {{
$crate::py_run_impl!($py, *$dict, $crate::indoc::indoc!($code))
}};
($py:expr, *$dict:expr, $code:expr) => {{
$crate::py_run_impl!($py, *$dict, &$crate::unindent::unindent($code))
}};
}

#[macro_export]
Expand All @@ -355,8 +379,10 @@ macro_rules! py_run_impl {
use $crate::types::IntoPyDict;
use $crate::ToPyObject;
let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py);

if let Err(e) = $py.run($code, None, Some(d)) {
$crate::py_run_impl!($py, *d, $code)
}};
($py:expr, *$dict:expr, $code:expr) => {{
if let Err(e) = $py.run($code, None, Some($dict)) {
e.print($py);
// So when this c api function the last line called printed the error to stderr,
// the output is only written into a buffer which is never flushed because we
Expand Down
44 changes: 25 additions & 19 deletions tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
//! Useful tips for writing tests:
//! - Tests are run in parallel; There's still a race condition in test_owned with some other test
//! - You need to use flush=True to get any output from print
//! Some common macros for tests
#[macro_export]
macro_rules! py_assert {
($py:expr, $val:ident, $assertion:expr) => {
pyo3::py_run!($py, $val, concat!("assert ", $assertion))
($py:expr, $($val:ident)+, $assertion:literal) => {
pyo3::py_run!($py, $($val)+, concat!("assert ", $assertion))
};
($py:expr, *$dict:expr, $assertion:literal) => {
pyo3::py_run!($py, *$dict, concat!("assert ", $assertion))
};
}

#[macro_export]
macro_rules! py_expect_exception {
($py:expr, $val:ident, $code:expr, $err:ident) => {{
// Case1: idents & no err_msg
($py:expr, $($val:ident)+, $code:expr, $err:ident) => {{
use pyo3::types::IntoPyDict;
let d = [(stringify!($val), &$val)].into_py_dict($py);

let res = $py.run($code, None, Some(d));
let d = [$((stringify!($val), $val.to_object($py)),)+].into_py_dict($py);
py_expect_exception!($py, *d, $code, $err)
}};
// Case2: dict & no err_msg
($py:expr, *$dict:expr, $code:expr, $err:ident) => {{
let res = $py.run($code, None, Some($dict));
let err = res.expect_err(&format!("Did not raise {}", stringify!($err)));
if !err.matches($py, $py.get_type::<pyo3::exceptions::$err>()) {
panic!("Expected {} but got {:?}", stringify!($err), err)
}
err
}};
($py:expr, $val:ident, $code:expr, $err:ident, $err_msg:expr) => {{
let err = py_expect_exception!($py, $val, $code, $err);
assert_eq!(
err.instance($py)
.str()
.expect("error str() failed")
.to_str()
.expect("message was not valid utf8"),
$err_msg
);
// Case3: idents & err_msg
($py:expr, $($val:ident)+, $code:expr, $err:ident, $err_msg:literal) => {{
let err = py_expect_exception!($py, $($val)+, $code, $err);
// Suppose that the error message looks like 'TypeError: ~'
assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg));
err
}};
// Case4: dict & err_msg
($py:expr, *$dict:expr, $code:expr, $err:ident, $err_msg:literal) => {{
let err = py_expect_exception!($py, *$dict, $code, $err);
assert_eq!(format!("Py{}", err), concat!(stringify!($err), ": ", $err_msg));
err
}};
}
5 changes: 3 additions & 2 deletions tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

mod common;

#[pyclass]
struct TestBufferClass {
vec: Vec<u8>,
Expand Down Expand Up @@ -93,8 +95,7 @@ fn test_buffer() {
)
.unwrap();
let env = [("ob", instance)].into_py_dict(py);
py.run("assert bytes(ob) == b' 23'", None, Some(env))
.unwrap();
py_assert!(py, *env, "bytes(ob) == b' 23'");
}

assert!(drop_called.load(Ordering::Relaxed));
Expand Down
8 changes: 3 additions & 5 deletions tests/test_dunder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::class::{
};
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PySlice, PyType};
use pyo3::types::{PySlice, PyType};
use pyo3::{ffi, py_run, AsPyPointer, PyCell};
use std::convert::TryFrom;
use std::{isize, iter};
Expand Down Expand Up @@ -450,11 +450,9 @@ fn test_cls_impl() {
let py = gil.python();

let ob = Py::new(py, Test {}).unwrap();
let d = [("ob", ob)].into_py_dict(py);

py.run("assert ob[1] == 'int'", None, Some(d)).unwrap();
py.run("assert ob[100:200:1] == 'slice'", None, Some(d))
.unwrap();
py_assert!(py, ob, "ob[1] == 'int'");
py_assert!(py, ob, "ob[100:200:1] == 'slice'");
}

#[pyclass(dict, subclass)]
Expand Down
7 changes: 1 addition & 6 deletions tests/test_getter_setter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ fn class_with_properties() {
py_run!(py, inst, "assert inst.data_list == [42]");

let d = [("C", py.get_type::<ClassWithProperties>())].into_py_dict(py);
py.run(
"assert C.DATA.__doc__ == 'a getter for data'",
None,
Some(d),
)
.unwrap();
py_assert!(py, *d, "C.DATA.__doc__ == 'a getter for data'");
}

#[pyclass]
Expand Down
60 changes: 30 additions & 30 deletions tests/test_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ use std::collections::HashMap;

use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
use pyo3::py_run;
use pyo3::types::IntoPyDict;
use pyo3::types::PyList;
use pyo3::PyMappingProtocol;

mod common;

#[pyclass]
struct Mapping {
index: HashMap<String, usize>,
Expand Down Expand Up @@ -66,61 +69,58 @@ impl PyMappingProtocol for Mapping {
}
}

/// Return a dict with `m = Mapping(['1', '2', '3'])`.
fn map_dict(py: Python) -> &pyo3::types::PyDict {
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
py_run!(py, *d, "m = Mapping(['1', '2', '3'])");
d
}

#[test]
fn test_getitem() {
let gil = Python::acquire_gil();
let py = gil.python();
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);

let run = |code| py.run(code, None, Some(d)).unwrap();
let err = |code| py.run(code, None, Some(d)).unwrap_err();
let d = map_dict(py);

run("m = Mapping(['1', '2', '3']); assert m['1'] == 0");
run("m = Mapping(['1', '2', '3']); assert m['2'] == 1");
run("m = Mapping(['1', '2', '3']); assert m['3'] == 2");
err("m = Mapping(['1', '2', '3']); print(m['4'])");
py_assert!(py, *d, "m['1'] == 0");
py_assert!(py, *d, "m['2'] == 1");
py_assert!(py, *d, "m['3'] == 2");
py_expect_exception!(py, *d, "print(m['4'])", PyKeyError);
}

#[test]
fn test_setitem() {
let gil = Python::acquire_gil();
let py = gil.python();
let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);

let run = |code| py.run(code, None, Some(d)).unwrap();
let err = |code| py.run(code, None, Some(d)).unwrap_err();
let d = map_dict(py);

run("m = Mapping(['1', '2', '3']); m['1'] = 4; assert m['1'] == 4");
run("m = Mapping(['1', '2', '3']); m['0'] = 0; assert m['0'] == 0");
run("m = Mapping(['1', '2', '3']); len(m) == 4");
err("m = Mapping(['1', '2', '3']); m[0] = 'hello'");
err("m = Mapping(['1', '2', '3']); m[0] = -1");
py_run!(py, *d, "m['1'] = 4; assert m['1'] == 4");
py_run!(py, *d, "m['0'] = 0; assert m['0'] == 0");
py_assert!(py, *d, "len(m) == 4");
py_expect_exception!(py, *d, "m[0] = 'hello'", PyTypeError);
py_expect_exception!(py, *d, "m[0] = -1", PyTypeError);
}

#[test]
fn test_delitem() {
let gil = Python::acquire_gil();
let py = gil.python();

let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
let run = |code| py.run(code, None, Some(d)).unwrap();
let err = |code| py.run(code, None, Some(d)).unwrap_err();

run(
"m = Mapping(['1', '2', '3']); del m['1']; assert len(m) == 2; \
assert m['2'] == 1; assert m['3'] == 2",
let d = map_dict(py);
py_run!(
py,
*d,
"del m['1']; assert len(m) == 2 and m['2'] == 1 and m['3'] == 2"
);
err("m = Mapping(['1', '2', '3']); del m[-1]");
err("m = Mapping(['1', '2', '3']); del m['4']");
py_expect_exception!(py, *d, "del m[-1]", PyTypeError);
py_expect_exception!(py, *d, "del m['4']", PyKeyError);
}

#[test]
fn test_reversed() {
let gil = Python::acquire_gil();
let py = gil.python();

let d = [("Mapping", py.get_type::<Mapping>())].into_py_dict(py);
let run = |code| py.run(code, None, Some(d)).unwrap();

run("m = Mapping(['1', '2']); assert set(reversed(m)) == {'1', '2'}");
let d = map_dict(py);
py_assert!(py, *d, "set(reversed(m)) == {'1', '2', '3'}");
}
Loading

0 comments on commit acf7271

Please sign in to comment.