Skip to content

Commit

Permalink
Preserve source positions for assertion rewriting (#12867)
Browse files Browse the repository at this point in the history
Closes #12818

(cherry picked from commit fb74025)

Co-authored-by: Frank Hoffmann <[email protected]>
  • Loading branch information
patchback[bot] and 15r10nk authored Oct 9, 2024
1 parent 3d3ec57 commit 40741c4
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 11 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ Feng Ma
Florian Bruhin
Florian Dahlitz
Floris Bruynooghe
Frank Hoffmann
Fraser Stark
Gabriel Landau
Gabriel Reis
Expand Down
1 change: 1 addition & 0 deletions changelog/12818.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Assertion rewriting now preserves the source ranges of the original instructions, making it play well with tools that deal with the ``AST``, like `executing <https://github.com/alexmojaki/executing>`__.
21 changes: 14 additions & 7 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
return ast.copy_location(ast.Name(name, ast.Load()), expr)

def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
Expand Down Expand Up @@ -975,7 +975,10 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
for node in traverse_node(stmt):
ast.copy_location(node, assert_)
if getattr(node, "lineno", None) is None:
# apply the assertion location to all generated ast nodes without source location
# and preserve the location of existing nodes or generated nodes with an correct location.
ast.copy_location(node, assert_)
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
Expand Down Expand Up @@ -1052,15 +1055,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
def visit_UnaryOp(self, unary: ast.UnaryOp) -> tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
res = self.assign(ast.copy_location(ast.UnaryOp(unary.op, operand_res), unary))
return res, pattern % (operand_expl,)

def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
res = self.assign(
ast.copy_location(ast.BinOp(left_expr, binop.op, right_expr), binop)
)
return res, explanation

def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
Expand Down Expand Up @@ -1089,7 +1094,7 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
arg_expls.append("**" + expl)

expl = "{}({})".format(func_expl, ", ".join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
Expand All @@ -1105,7 +1110,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
res = self.assign(
ast.copy_location(ast.Attribute(value, attr.attr, ast.Load()), attr)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s.%s\n}"
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
Expand Down Expand Up @@ -1146,7 +1153,7 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
res_expr = ast.copy_location(ast.Compare(left_res, [op], [next_res]), comp)
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
Expand Down
211 changes: 207 additions & 4 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from __future__ import annotations

import ast
import dis
import errno
from functools import partial
import glob
import importlib
import inspect
import marshal
import os
from pathlib import Path
Expand Down Expand Up @@ -131,10 +133,211 @@ def test_location_is_set(self) -> None:
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert isinstance(n, (ast.stmt, ast.expr))
assert n.lineno == 3
assert n.col_offset == 0
assert n.end_lineno == 6
assert n.end_col_offset == 3
for location in [
(n.lineno, n.col_offset),
(n.end_lineno, n.end_col_offset),
]:
assert (3, 0) <= location <= (6, 3)

def test_positions_are_preserved(self) -> None:
"""Ensure AST positions are preserved during rewriting (#12818)."""

def preserved(code: str) -> None:
s = textwrap.dedent(code)
locations = []

def loc(msg: str | None = None) -> None:
frame = inspect.currentframe()
assert frame
frame = frame.f_back
assert frame
frame = frame.f_back
assert frame

offset = frame.f_lasti

instructions = {i.offset: i for i in dis.get_instructions(frame.f_code)}

# skip CACHE instructions
while offset not in instructions and offset >= 0:
offset -= 1

instruction = instructions[offset]
if sys.version_info >= (3, 11):
position = instruction.positions
else:
position = instruction.starts_line

locations.append((msg, instruction.opname, position))

globals = {"loc": loc}

m = rewrite(s)
mod = compile(m, "<string>", "exec")
exec(mod, globals, globals)
transformed_locations = locations
locations = []

mod = compile(s, "<string>", "exec")
exec(mod, globals, globals)
original_locations = locations

assert len(original_locations) > 0
assert original_locations == transformed_locations

preserved("""
def f():
loc()
return 8
assert f() in [8]
assert (f()
in
[8])
""")

preserved("""
class T:
def __init__(self):
loc("init")
def __getitem__(self,index):
loc("getitem")
return index
assert T()[5] == 5
assert (T
()
[5]
==
5)
""")

for name, op in [
("pos", "+"),
("neg", "-"),
("invert", "~"),
]:
preserved(f"""
class T:
def __{name}__(self):
loc("{name}")
return "{name}"
assert {op}T() == "{name}"
assert ({op}
T
()
==
"{name}")
""")

for name, op in [
("add", "+"),
("sub", "-"),
("mul", "*"),
("truediv", "/"),
("floordiv", "//"),
("mod", "%"),
("pow", "**"),
("lshift", "<<"),
("rshift", ">>"),
("or", "|"),
("xor", "^"),
("and", "&"),
("matmul", "@"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc("{name}")
return other
def __r{name}__(self,other):
loc("r{name}")
return other
assert T() {op} 2 == 2
assert 2 {op} T() == 2
assert (T
()
{op}
2
==
2)
assert (2
{op}
T
()
==
2)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", "<"),
("le", "<="),
("gt", ">"),
("ge", ">="),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert T() {op} 5
assert (T
()
{op}
5)
""")

for name, op in [
("eq", "=="),
("ne", "!="),
("lt", ">"),
("le", ">="),
("gt", "<"),
("ge", "<="),
("contains", "in"),
]:
preserved(f"""
class T:
def __{name}__(self,other):
loc()
return True
assert 5 {op} T()
assert (5
{op}
T
())
""")

preserved("""
def func(value):
loc("func")
return value
class T:
def __iter__(self):
loc("iter")
return iter([5])
assert func(*T()) == 5
""")

preserved("""
class T:
def __getattr__(self,name):
loc()
return name
assert T().attr == "attr"
""")

def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
Expand Down

0 comments on commit 40741c4

Please sign in to comment.