Skip to content

Commit

Permalink
Improve robustness of compare() in the face of user-modification of AST.
Browse files Browse the repository at this point in the history
The comparison fundamentally depends on _fields and _attributes, which
could be modified by the user. It's not clear that such modifications
are sensible or supported by the API, but we can at least make sure
comparison doesn't silently ignore those comparisons.

Also pass a and b as arguments to helper methods instead of using them
from the enclosing scope.
  • Loading branch information
jeremyhylton committed May 21, 2024
1 parent 69be6d2 commit bdd2d66
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
14 changes: 8 additions & 6 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,9 @@ def _compare(a, b):
else:
return type(a) is type(b) and a == b

def _compare_fields():
def _compare_fields(a, b):
if a._fields != b._fields:
return False
for field in a._fields:
a_field = getattr(a, field)
b_field = getattr(b, field)
Expand All @@ -450,7 +452,9 @@ def _compare_fields():
else:
return True

def _compare_attributes():
def _compare_attributes(a, b):
if a._attributes != b._attributes:
return False
# Attributes are always strings.
for attr in a._attributes:
a_attr = getattr(a, attr)
Expand All @@ -462,11 +466,9 @@ def _compare_attributes():

if type(a) is not type(b):
return False
# a and b are guaranteed to have the same type, so they must also
# have identical values for _fields and _attributes.
if not _compare_fields():
if not _compare_fields(a, b):
return False
if compare_attributes and not _compare_attributes():
if compare_attributes and not _compare_attributes(a, b):
return False
return True

Expand Down
30 changes: 30 additions & 0 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,36 @@ def test_compare_basics(self):
ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
)

def test_compare_modified_ast(self):
# The ast API is a bit underspecified. The objects are mutable,
# and even _fields and _attributes are mutable. The compare() does
# some simple things to accommodate mutability.
a = ast.parse("m * x + b", mode="eval")
b = ast.parse("m * x + b", mode="eval")
self.assertTrue(ast.compare(a, b))

a._fields = a._fields + ("spam",)
a.spam = "Spam"
self.assertNotEqual(a._fields, b._fields)
self.assertFalse(ast.compare(a, b))
self.assertFalse(ast.compare(b, a))

b._fields = a._fields
b.spam = "Spam"
self.assertTrue(ast.compare(a, b))
self.assertTrue(ast.compare(b, a))

b._attributes = b._attributes + ("eggs",)
b.eggs = "eggs"
self.assertNotEqual(a._attributes, b._attributes)
self.assertFalse(ast.compare(a, b, compare_attributes=True))
self.assertFalse(ast.compare(b, a, compare_attributes=True))

a._attributes = b._attributes
a.eggs = b.eggs
self.assertTrue(ast.compare(a, b, compare_attributes=True))
self.assertTrue(ast.compare(b, a, compare_attributes=True))

def test_compare_literals(self):
constants = (
-20,
Expand Down

0 comments on commit bdd2d66

Please sign in to comment.