Skip to content

Commit d065edf

Browse files
gh-60191: Implement ast.compare (#19211)
* bpo-15987: Implement ast.compare Add a compare() function that compares two ASTs for structural equality. There are two sets of attributes on AST node objects, fields and attributes. The fields are always compared, since they represent the actual structure of the code. The attributes can optionally be included in the comparison. Attributes capture things like line numbers and column offsets, so comparing them involves testing whether the layout of the program text is the same. Since whitespace seems inessential for comparing ASTs, the default is to compare fields but not attributes. ASTs are just Python objects that can be modified in arbitrary ways. The API for ASTs is under-specified in the presence of user modifications to objects. The comparison respects modifications to fields and attributes, and to _fields and _attributes attributes. A user could create obviously malformed objects, and the code will probably fail with an AttributeError when that happens. (For example, adding "spam" to _fields but not adding a "spam" attribute to the object.) Co-authored-by: Jeremy Hylton <[email protected]>
1 parent 0e3c8cd commit d065edf

File tree

5 files changed

+210
-5
lines changed

5 files changed

+210
-5
lines changed

Doc/library/ast.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2472,6 +2472,20 @@ effects on the compilation of a program:
24722472
.. versionadded:: 3.8
24732473

24742474

2475+
.. function:: compare(a, b, /, *, compare_attributes=False)
2476+
2477+
Recursively compares two ASTs.
2478+
2479+
*compare_attributes* affects whether AST attributes are considered
2480+
in the comparison. If *compare_attributes* is ``False`` (default), then
2481+
attributes are ignored. Otherwise they must all be equal. This
2482+
option is useful to check whether the ASTs are structurally equal but
2483+
differ in whitespace or similar details. Attributes include line numbers
2484+
and column offsets.
2485+
2486+
.. versionadded:: 3.14
2487+
2488+
24752489
.. _ast-cli:
24762490

24772491
Command-Line Usage

Doc/whatsnew/3.14.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ New Modules
8686
Improved Modules
8787
================
8888

89+
ast
90+
---
91+
92+
Added :func:`ast.compare` for comparing two ASTs.
93+
(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`)
94+
95+
8996

9097
Optimizations
9198
=============

Lib/ast.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,77 @@ def walk(node):
401401
yield node
402402

403403

404+
def compare(
    a,
    b,
    /,
    *,
    compare_attributes=False,
):
    """Recursively compares two ASTs.

    compare_attributes affects whether AST attributes are considered
    in the comparison. If compare_attributes is False (default), then
    attributes are ignored. Otherwise they must all be equal. This
    option is useful to check whether the ASTs are structurally equal but
    might differ in whitespace or similar details.

    Returns True if both nodes have exactly the same type and all of
    their fields (and, optionally, attributes) compare equal; otherwise
    returns False.  A name listed in _fields or _attributes that is not
    actually set on a node is treated as "missing": two nodes agree on
    it only if it is missing on both.
    """

    # Unique marker for a field/attribute that is declared in
    # _fields/_attributes but not set on the node (for example the
    # position attributes of a hand-built node).  None cannot be used
    # because it is a legitimate field value.
    sentinel = object()

    def _compare(a, b):
        # Compare two fields on an AST object, which may themselves be
        # AST objects, lists of AST objects, or primitive ASDL types
        # like identifiers and constants.
        if isinstance(a, AST):
            return compare(
                a,
                b,
                compare_attributes=compare_attributes,
            )
        elif isinstance(a, list):
            # If a field is repeated, then both objects will represent
            # the value as a list.  A user-modified tree may hold
            # something else (e.g. None) on the other side, so check
            # before calling len() on it.
            if not isinstance(b, list) or len(a) != len(b):
                return False
            return all(
                _compare(a_item, b_item) for a_item, b_item in zip(a, b)
            )
        else:
            # Exact type match keeps 1, 1.0 and True distinct even
            # though they compare equal with ==.
            return type(a) is type(b) and a == b

    def _compare_fields(a, b):
        if a._fields != b._fields:
            return False
        for field in a._fields:
            a_field = getattr(a, field, sentinel)
            b_field = getattr(b, field, sentinel)
            if a_field is sentinel and b_field is sentinel:
                # Both nodes are missing this field at runtime.
                continue
            if a_field is sentinel or b_field is sentinel:
                # Only one of the nodes is missing this field.
                return False
            if not _compare(a_field, b_field):
                return False
        return True

    def _compare_attributes(a, b):
        if a._attributes != b._attributes:
            return False
        # Attributes are compared with plain equality (no recursion).
        for attr in a._attributes:
            a_attr = getattr(a, attr, sentinel)
            b_attr = getattr(b, attr, sentinel)
            if a_attr is sentinel and b_attr is sentinel:
                # Both nodes are missing this attribute at runtime.
                continue
            # If only one side is missing, the sentinel compares unequal
            # to any real value, so this also returns False.
            if a_attr != b_attr:
                return False
        return True

    if type(a) is not type(b):
        return False
    if not _compare_fields(a, b):
        return False
    if compare_attributes and not _compare_attributes(a, b):
        return False
    return True
473+
474+
404475
class NodeVisitor(object):
405476
"""
406477
A node visitor base class that walks the abstract syntax tree and calls a

Lib/test/test_ast.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def to_tuple(t):
3838
result.append(to_tuple(getattr(t, f)))
3939
return tuple(result)
4040

41+
STDLIB = os.path.dirname(ast.__file__)
42+
STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
43+
STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
4144

4245
# These tests are compiled through "exec"
4346
# There should be at least one test per statement
@@ -1066,6 +1069,114 @@ def test_ast_asdl_signature(self):
10661069
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
10671070
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
10681071

1072+
def test_compare_basics(self):
    # A tree equals a fresh parse of the same source and differs from
    # trees parsed from structurally different sources.
    assignment = "x = 10"
    same = ast.compare(ast.parse(assignment), ast.parse(assignment))
    self.assertTrue(same)

    versus_empty = ast.compare(ast.parse(assignment), ast.parse(""))
    self.assertFalse(versus_empty)

    versus_name = ast.compare(ast.parse(assignment), ast.parse("x"))
    self.assertFalse(versus_name)

    two_statements = ast.parse("x = 10;y = 20")
    class_definition = ast.parse("class C:pass")
    self.assertFalse(ast.compare(two_statements, class_definition))
1079+
1080+
def test_compare_modified_ast(self):
    # AST nodes are ordinary mutable objects; even _fields and
    # _attributes can be rebound per instance.  compare() is expected to
    # honour such modifications, in both argument orders.
    left = ast.parse("m * x + b", mode="eval")
    right = ast.parse("m * x + b", mode="eval")
    self.assertTrue(ast.compare(left, right))

    # An extra field on one side only makes the trees unequal.
    left._fields = (*left._fields, "spam")
    left.spam = "Spam"
    self.assertNotEqual(left._fields, right._fields)
    for first, second in ((left, right), (right, left)):
        self.assertFalse(ast.compare(first, second))

    # Mirroring the extra field on the other side restores equality.
    right._fields = left._fields
    right.spam = left.spam
    for first, second in ((left, right), (right, left)):
        self.assertTrue(ast.compare(first, second))

    # Same exercise for _attributes, which only matter when requested.
    right._attributes = (*right._attributes, "eggs")
    right.eggs = "eggs"
    self.assertNotEqual(left._attributes, right._attributes)
    for first, second in ((left, right), (right, left)):
        self.assertFalse(
            ast.compare(first, second, compare_attributes=True)
        )

    left._attributes = right._attributes
    left.eggs = right.eggs
    for first, second in ((left, right), (right, left)):
        self.assertTrue(
            ast.compare(first, second, compare_attributes=True)
        )
1109+
1110+
def test_compare_literals(self):
    # Each constant must equal itself and differ from its successor in
    # this sequence.
    constants = (
        -20,
        20,
        20.0,
        1,
        1.0,
        True,
        0,
        False,
        frozenset(),
        tuple(),
        "ABCD",
        "abcd",
        "中文字",
        1e1000,
        -1e1000,
    )
    for next_index, constant in enumerate(constants[:-1], 1):
        next_constant = constants[next_index]
        with self.subTest(literal=constant, next_literal=next_constant):
            self.assertTrue(
                ast.compare(ast.Constant(constant), ast.Constant(constant))
            )
            self.assertFalse(
                ast.compare(
                    ast.Constant(constant), ast.Constant(next_constant)
                )
            )

    # Values that are == to each other but have different types must
    # not compare equal as AST constants.  These groups have to be
    # tuples, not sets: 1, 1.0, True and 1 + 0j are equal and hash
    # alike, so a set literal would collapse them into a single element
    # and the loop below would never run an assertion.
    same_looking_literal_cases = [
        (1, 1.0, True, 1 + 0j),
        (0, 0.0, False, 0 + 0j),
    ]
    for same_looking_literals in same_looking_literal_cases:
        for index, literal in enumerate(same_looking_literals):
            for other_index, same_looking_literal in enumerate(
                same_looking_literals
            ):
                # Pair by position; the values themselves are all ==.
                if index == other_index:
                    continue
                with self.subTest(
                    literal=repr(literal),
                    same_looking_literal=repr(same_looking_literal),
                ):
                    self.assertFalse(
                        ast.compare(
                            ast.Constant(literal),
                            ast.Constant(same_looking_literal),
                        )
                    )
1153+
1154+
def test_compare_fieldless(self):
    # Nodes without any fields are told apart by their type alone.
    same_operator = ast.compare(ast.Add(), ast.Add())
    different_operator = ast.compare(ast.Sub(), ast.Add())
    self.assertTrue(same_operator)
    self.assertFalse(different_operator)
1157+
1158+
def test_compare_modes(self):
    # Parsing the same source twice must give equal trees in every
    # parse mode.
    samples_by_mode = {
        "exec": exec_tests,
        "eval": eval_tests,
        "single": single_tests,
    }
    for mode, sources in samples_by_mode.items():
        for source in sources:
            first = ast.parse(source, mode=mode)
            second = ast.parse(source, mode=mode)
            are_equal = ast.compare(first, second)
            failure_message = f"{ast.dump(first)} != {ast.dump(second)}"
            self.assertTrue(are_equal, failure_message)
1170+
1171+
def test_compare_attributes_option(self):
    # The same expression written with different spacing: the trees
    # match unless attributes are included in the comparison.
    spaced = ast.parse("2 + 2")
    compact = ast.parse("2+2")

    default_result = ast.compare(spaced, compact)
    ignoring_attributes = ast.compare(
        spaced, compact, compare_attributes=False
    )
    including_attributes = ast.compare(
        spaced, compact, compare_attributes=True
    )

    self.assertTrue(default_result)
    self.assertTrue(ignoring_attributes)
    self.assertFalse(including_attributes)
1179+
10691180
def test_positional_only_feature_version(self):
10701181
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
10711182
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
@@ -1222,6 +1333,7 @@ def test_none_checks(self) -> None:
12221333
for node, attr, source in tests:
12231334
self.assert_none_check(node, attr, source)
12241335

1336+
12251337
class ASTHelpers_Test(unittest.TestCase):
12261338
maxDiff = None
12271339

@@ -2191,16 +2303,15 @@ def test_nameconstant(self):
21912303

21922304
@support.requires_resource('cpu')
21932305
def test_stdlib_validates(self):
2194-
stdlib = os.path.dirname(ast.__file__)
2195-
tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
2196-
tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
2197-
for module in tests:
2306+
for module in STDLIB_FILES:
21982307
with self.subTest(module):
2199-
fn = os.path.join(stdlib, module)
2308+
fn = os.path.join(STDLIB, module)
22002309
with open(fn, "r", encoding="utf-8") as fp:
22012310
source = fp.read()
22022311
mod = ast.parse(source, fn)
22032312
compile(mod, fn, "exec")
2313+
mod2 = ast.parse(source, fn)
2314+
self.assertTrue(ast.compare(mod, mod2))
22042315

22052316
constant_1 = ast.Constant(1)
22062317
pattern_1 = ast.MatchValue(constant_1)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Implemented :func:`ast.compare` for comparing two ASTs. Patch by Batuhan
2+
Taskaya with some help from Jeremy Hylton.

0 commit comments

Comments
 (0)