Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import ast
import dis
import os
import random
import sys
import tokenize
import unittest
import warnings
import weakref
Expand All @@ -25,6 +27,9 @@ def to_tuple(t):
result.append(to_tuple(getattr(t, f)))
return tuple(result)

STDLIB = os.path.dirname(ast.__file__)
STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])

# These tests are compiled through "exec"
# There should be at least one test per statement
Expand Down Expand Up @@ -654,6 +659,70 @@ def test_ast_asdl_signature(self):
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)

def test_compare_basis(self):
self.assertEqual(ast.parse("x = 10"), ast.parse("x = 10"))
self.assertNotEqual(ast.parse("x = 10"), ast.parse(""))
self.assertNotEqual(ast.parse("x = 10"), ast.parse("x"))
self.assertNotEqual(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))

def test_compare_literals(self):
constants = (-20, 20, 20.0, 1, 1.0, True, 0, False, frozenset(), tuple(), "ABCD", "abcd", "中文字", 1e1000, -1e1000)
for next_index, constant in enumerate(constants[:-1], 1):
next_constant = constants[next_index]
with self.subTest(literal=constant, next_literal=next_constant):
self.assertEqual(ast.Constant(constant), ast.Constant(constant))
self.assertNotEqual(ast.Constant(constant), ast.Constant(next_constant))

same_looking_literal_cases = [{1, 1.0, True, 1+0j}, {0, 0.0, False, 0+0j}]
for same_looking_literals in same_looking_literal_cases:
for literal in same_looking_literals:
for same_looking_literal in same_looking_literals - {literal}:
self.assertNotEqual(ast.Constant(literal), ast.Constant(same_looking_literal))

def test_compare_operators(self):
self.assertEqual(ast.Add(), ast.Add())
self.assertEqual(ast.Sub(), ast.Sub())

self.assertNotEqual(ast.Add(), ast.Sub())
self.assertNotEqual(ast.Add(), ast.Constant())

def test_compare_stdlib(self):
if support.is_resource_enabled("cpu"):
files = STDLIB_FILES
else:
files = random.sample(STDLIB_FILES, 10)

for module in files:
with self.subTest(module):
fn = os.path.join(STDLIB, module)
with tokenize.open(fn) as fp:
source = fp.read()
a = ast.parse(source, fn)
b = ast.parse(source, fn)
self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}")
self.assertFalse(a != b)

def test_exec_compare(self):
for source in exec_tests:
a = ast.parse(source, mode="exec")
b = ast.parse(source, mode="exec")
self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}")
self.assertFalse(a != b)

def test_single_compare(self):
for source in single_tests:
a = ast.parse(source, mode="single")
b = ast.parse(source, mode="single")
self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}")
self.assertFalse(a != b)

def test_eval_compare(self):
for source in eval_tests:
a = ast.parse(source, mode="eval")
b = ast.parse(source, mode="eval")
self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}")
self.assertFalse(a != b)


class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
Expand Down Expand Up @@ -1369,12 +1438,9 @@ def test_nameconstant(self):
self.expr(ast.NameConstant(4))

def test_stdlib_validates(self):
stdlib = os.path.dirname(ast.__file__)
tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
for module in tests:
for module in STDLIB_FILES:
with self.subTest(module):
fn = os.path.join(stdlib, module)
fn = os.path.join(STDLIB, module)
with open(fn, "r", encoding="utf-8") as fp:
source = fp.read()
mod = ast.parse(source, fn)
Expand Down
1 change: 1 addition & 0 deletions Misc/ACKS
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,7 @@ Jason Lowe
Tony Lownds
Ray Loyzaga
Kang-Hao (Kenny) Lu
Louie Lu
Lukas Lueg
Loren Luke
Fredrik Lundh
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Provide a way to compare AST nodes for equality recursively. Patch by Louie
Lu, Flavian Hautbois and Batuhan Taskaya.
84 changes: 84 additions & 0 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,88 @@ def visitModule(self, mod):
return Py_BuildValue("O()", Py_TYPE(self));
}

static PyObject *
ast_richcompare(PyObject *self, PyObject *other, int op)
{
Py_ssize_t i, numfields = 0;
PyObject *fields, *key = NULL;

/* Check operator */
if ((op != Py_EQ && op != Py_NE) ||
!PyAST_Check(self) || !PyAST_Check(other)) {
Py_RETURN_NOTIMPLEMENTED;
}

/* Compare types */
if (Py_TYPE(self) != Py_TYPE(other)) {
Py_RETURN_RICHCOMPARE(Py_TYPE(self), Py_TYPE(other), op);
}

if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), astmodulestate_global->_fields, &fields) < 0) {
return NULL;
}
if (fields) {
numfields = PySequence_Size(fields);
if (numfields == -1) {
goto fail;
}
}

PyObject *a, *b;
/* Compare fields */
for (i = 0; i < numfields; i++) {
key = PySequence_GetItem(fields, i);
if (!key) {
goto fail;
}
if (!PyObject_HasAttr(self, key) || !PyObject_HasAttr(other, key)) {
Py_DECREF(key);
goto unsuccessful;
}
Py_DECREF(key);

a = PyObject_GetAttr(self, key);
b = PyObject_GetAttr(other, key);
if (!a || !b) {
goto unsuccessful;
}

/* Ensure they belong to the same type */
if (Py_TYPE(a) != Py_TYPE(b)) {
goto unsuccessful;
}

if (!PyObject_RichCompareBool(a, b, Py_EQ)) {
goto unsuccessful;
}
Py_DECREF(a);
Py_DECREF(b);
}
Py_DECREF(fields);

if (op == Py_EQ) {
Py_RETURN_TRUE;
}
else {
Py_RETURN_FALSE;
}

unsuccessful:
Py_XDECREF(a);
Py_XDECREF(b);
Py_DECREF(fields);
if (op == Py_EQ) {
Py_RETURN_FALSE;
}
else {
Py_RETURN_TRUE;
}

fail:
Py_DECREF(fields);
return NULL;
}

static PyMemberDef ast_type_members[] = {
{"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY},
{NULL} /* Sentinel */
Expand Down Expand Up @@ -770,6 +852,8 @@ def visitModule(self, mod):
{Py_tp_alloc, PyType_GenericAlloc},
{Py_tp_new, PyType_GenericNew},
{Py_tp_free, PyObject_GC_Del},
{Py_tp_richcompare, ast_richcompare},
{Py_tp_hash, (hashfunc)_Py_HashPointer},
{0, 0},
};

Expand Down
84 changes: 84 additions & 0 deletions Python/Python-ast.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.