diff --git a/Lib/ast.py b/Lib/ast.py index ced9dfce11fec3..6ea0524e00ed1b 100644 --- a/Lib/ast.py +++ b/Lib/ast.py @@ -441,7 +441,9 @@ def _compare(a, b): else: return type(a) is type(b) and a == b - def _compare_fields(): + def _compare_fields(a, b): + if a._fields != b._fields: + return False for field in a._fields: a_field = getattr(a, field) b_field = getattr(b, field) @@ -450,7 +452,9 @@ def _compare_fields(): else: return True - def _compare_attributes(): + def _compare_attributes(a, b): + if a._attributes != b._attributes: + return False # Attributes are always strings. for attr in a._attributes: a_attr = getattr(a, attr) @@ -462,11 +466,9 @@ def _compare_attributes(): if type(a) is not type(b): return False - # a and b are guaranteed to have the same type, so they must also - # have identical values for _fields and _attributes. - if not _compare_fields(): + if not _compare_fields(a, b): return False - if compare_attributes and not _compare_attributes(): + if compare_attributes and not _compare_attributes(a, b): return False return True diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index f91075533bde5b..535a8dfe2dcf1f 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1077,6 +1077,36 @@ def test_compare_basics(self): ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass")) ) + def test_compare_modified_ast(self): + # The ast API is a bit underspecified. The objects are mutable, + # and even _fields and _attributes are mutable. The compare() does + # some simple things to accommodate mutability. + a = ast.parse("m * x + b", mode="eval") + b = ast.parse("m * x + b", mode="eval") + self.assertTrue(ast.compare(a, b)) + + a._fields = a._fields + ("spam",) + a.spam = "Spam" + self.assertNotEqual(a._fields, b._fields) + self.assertFalse(ast.compare(a, b)) + self.assertFalse(ast.compare(b, a)) + + b._fields = a._fields + b.spam = "Spam" + self.assertTrue(ast.compare(a, b)) + self.assertTrue(ast.compare(b, a)) + + b._attributes = b._attributes + ("eggs",) + b.eggs = "eggs" + self.assertNotEqual(a._attributes, b._attributes) + self.assertFalse(ast.compare(a, b, compare_attributes=True)) + self.assertFalse(ast.compare(b, a, compare_attributes=True)) + + a._attributes = b._attributes + a.eggs = b.eggs + self.assertTrue(ast.compare(a, b, compare_attributes=True)) + self.assertTrue(ast.compare(b, a, compare_attributes=True)) + def test_compare_literals(self): constants = ( -20,