Skip to content

bpo-39981: Introduce default values for AST node classes #21417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Doc/whatsnew/3.10.rst
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ New Modules
Improved Modules
================

ast
---

Sequential and optional fields for AST nodes are now auto-initialized with the
corresponding empty values. See the :ref:`ASDL <abstract-grammar>` for more
information about the AST node classes and fields they have.
(Contributed by Batuhan Taskaya in :issue:`39981`)

base64
------

Expand Down
1 change: 1 addition & 0 deletions Include/Python-ast.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions Lib/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ def dump(node, annotate_fields=True, include_attributes=False, *, indent=None):
integer or string, then the tree will be pretty-printed with that indent
level. None (the default) selects the single line representation.
"""
def _qualifier_to_default(qualifier):
if qualifier == 1:
return []
elif qualifier == 2:
return None
else:
return ...

def _format(node, level=0):
if indent is not None:
level += 1
Expand All @@ -130,13 +138,14 @@ def _format(node, level=0):
args = []
allsimple = True
keywords = annotate_fields
for name in node._fields:
for name, qualifier in zip(node._fields, node._field_qualifiers):
default_value = _qualifier_to_default(qualifier)
try:
value = getattr(node, name)
except AttributeError:
keywords = True
continue
if value is None and getattr(cls, name, ...) is None:
if value is None and default_value is None:
keywords = True
continue
value, simple = _format(value, level)
Expand Down
37 changes: 35 additions & 2 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,19 +358,52 @@ def test_arguments(self):
self.assertEqual(x._fields, ('posonlyargs', 'args', 'vararg', 'kwonlyargs',
'kw_defaults', 'kwarg', 'defaults'))

with self.assertRaises(AttributeError):
x.args
self.assertIsNone(x.vararg)
self.assertEqual(x.args, [])

x = ast.arguments(*range(1, 8))
self.assertEqual(x.args, 2)
self.assertEqual(x.vararg, 3)

def test_field_defaults(self):
func = ast.FunctionDef("foo", ast.arguments())
self.assertEqual(func.name, "foo")
self.assertEqual(ast.dump(func.args), ast.dump(ast.arguments()))
self.assertEqual(func.body, [])
self.assertEqual(func.decorator_list, [])
self.assertEqual(func.returns, None)
self.assertEqual(func.type_comment, None)

func2 = ast.FunctionDef()
with self.assertRaises(AttributeError):
func2.name2

self.assertEqual(func.body, [])
self.assertEqual(func.returns, None)

func3 = ast.FunctionDef(body=[1])
self.assertEqual(func3.body, [1])
self.assertFalse(hasattr(func3, "name"))
self.assertTrue(hasattr(func3, "returns"))

def test_field_attr_writable(self):
x = ast.Num()
# We can assign to _fields
x._fields = 666
x._field_qualifiers = 999
self.assertEqual(x._fields, 666)
self.assertEqual(x._field_qualifiers, 999)

functiondef_qualifiers = ast.FunctionDef._field_qualifiers
del ast.FunctionDef._field_qualifiers
fnctdef = ast.FunctionDef("foo")
self.assertEqual(fnctdef.name, "foo")
with self.assertRaises(AttributeError):
fnctdef.body
ast.FunctionDef._field_qualifiers = (5,) * len(functiondef_qualifiers)
with self.assertRaises(ValueError):
ast.FunctionDef() # 5 as a field qualifier is an invalid value
ast.FunctionDef._field_qualifiers = functiondef_qualifiers

def test_classattrs(self):
x = ast.Num()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introduce default values for AST node class initializations.
156 changes: 124 additions & 32 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ class TypeDefVisitor(EmitVisitor):
def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)

self.emit(
"typedef enum _field_qualifier {Q_SEQUENCE=1, Q_OPTIONAL=2} "
"field_qualifier;", 0
)
def visitType(self, type, depth=0):
self.visit(type.value, type.name, depth)

Expand Down Expand Up @@ -665,6 +668,7 @@ def visitProduct(self, prod, name):
for f in prod.fields:
self.emit('"%s",' % f.name, 1)
self.emit("};", 0)
self._emit_field_qualifiers(name, prod.fields, 0)

def visitSum(self, sum, name):
self.emit_type("%s_type" % name)
Expand Down Expand Up @@ -692,6 +696,19 @@ def visitConstructor(self, cons, name):
for t in cons.fields:
self.emit('"%s",' % t.name, 1)
self.emit("};",0)
self._emit_field_qualifiers(cons.name, cons.fields, 0)

def _emit_field_qualifiers(self, name, fields, depth):
self.emit("static const field_qualifier %s_field_qualifiers[]={" % name, depth)
for field in fields:
if field.seq:
qualifier = "Q_SEQUENCE"
elif field.opt:
qualifier = "Q_OPTIONAL"
else:
qualifier = "0"
self.emit("%s, // %s" % (qualifier, field.name), depth+1)
self.emit("};", depth)


class PyTypesVisitor(PickleVisitor):
Expand Down Expand Up @@ -742,7 +759,7 @@ def visitModule(self, mod):

Py_ssize_t i, numfields = 0;
int res = -1;
PyObject *key, *value, *fields;
PyObject *key, *value, *fields, *field_qualifiers = NULL;
if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
Expand Down Expand Up @@ -802,8 +819,67 @@ def visitModule(self, mod):
}
}
}
if (_PyObject_LookupAttr(self, state->_field_qualifiers, &field_qualifiers) < 0) {
res = -1;
goto cleanup;
}

if (!PyTuple_CheckExact(field_qualifiers) || PyTuple_Size(field_qualifiers) != numfields) {
goto cleanup;
}

PyObject *field, *field_qualifier;
for (i = 0; i < numfields; i++) {
field = PySequence_GetItem(fields, i);
field_qualifier = PySequence_GetItem(field_qualifiers, i);
if (!field_qualifier || !field) {
res = -1;
goto next_iteration;
}

if (PyObject_HasAttr(self, field)) {
goto next_iteration;
}

PyObject *field_default = NULL;
switch (PyLong_AsLong(field_qualifier)) {
case -1:
res = -1;
goto next_iteration;
case 0:
goto next_iteration;
case Q_SEQUENCE:
field_default = PyList_New(0);
if (field_default == NULL) {
res = -1;
goto next_iteration;
}
break;
case Q_OPTIONAL:
field_default = Py_None;
Py_INCREF(field_default);
break;
default:
PyErr_Format(PyExc_ValueError,
"Unknown field qualifier: \\"%R\\"", field_qualifier);
res = -1;
goto next_iteration;
}
assert(field_default != NULL);
res = PyObject_SetAttr(self, field, field_default);
Py_DECREF(field_default);
next_iteration:
Py_XDECREF(field);
Py_XDECREF(field_qualifier);
if (res < 0) {
goto cleanup;
}
continue;
}

cleanup:
Py_XDECREF(fields);
Py_XDECREF(field_qualifiers);
return res;
}

Expand Down Expand Up @@ -866,29 +942,45 @@ def visitModule(self, mod):
};

static PyObject *
make_type(astmodulestate *state, const char *type, PyObject* base,
const char* const* fields, int num_fields, const char *doc)
make_type(
astmodulestate *state,
const char *type,
PyObject* base,
const char* const* fields,
const field_qualifier* field_qualifiers,
Py_ssize_t num_fields,
const char *doc
)
{
PyObject *fnames, *result;
int i;
fnames = PyTuple_New(num_fields);
if (!fnames) return NULL;
Py_ssize_t i;
PyObject *result = NULL;
PyObject *fnames = PyTuple_New(num_fields);
PyObject *fqualifiers = PyTuple_New(num_fields);

if (!fnames || !fqualifiers) {
goto exit;
}

for (i = 0; i < num_fields; i++) {
PyObject *field = PyUnicode_InternFromString(fields[i]);
if (!field) {
Py_DECREF(fnames);
return NULL;
PyObject *qualifier = PyLong_FromLong((long)field_qualifiers[i]);
if (!field || !qualifier) {
goto exit;
}
PyTuple_SET_ITEM(fnames, i, field);
PyTuple_SET_ITEM(fqualifiers, i, qualifier);
}
result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOs}",
result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}",
type, base,
state->_fields, fnames,
state->_field_qualifiers, fqualifiers,
state->__module__,
state->ast,
state->__doc__, doc);
Py_DECREF(fnames);
return result;
exit:
Py_XDECREF(fnames);
Py_XDECREF(fqualifiers);
return result;
}

static int
Expand Down Expand Up @@ -1012,8 +1104,10 @@ def visitModule(self, mod):
{
PyObject *empty_tuple;
empty_tuple = PyTuple_New(0);
Py_XINCREF(empty_tuple); // for _field_qualifiers
if (!empty_tuple ||
PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
PyObject_SetAttrString(state->AST_type, "_field_qualifiers", empty_tuple) < 0 ||
PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
Py_XDECREF(empty_tuple);
return -1;
Expand All @@ -1040,10 +1134,11 @@ def visitModule(self, mod):
def visitProduct(self, prod, name):
if prod.fields:
fields = name+"_fields"
field_qualifiers = name+"_field_qualifiers"
else:
fields = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,' %
(name, name, fields, len(prod.fields)), 1)
fields = field_qualifiers = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %s, %d,' %
(name, name, fields, field_qualifiers, len(prod.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, prod), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % name, 1)
self.emit_type("AST_type")
Expand All @@ -1053,11 +1148,9 @@ def visitProduct(self, prod, name):
(name, name, len(prod.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit_defaults(name, prod.fields, 1)
self.emit_defaults(name, prod.attributes, 1)

def visitSum(self, sum, name):
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,' %
self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, NULL, 0,' %
(name, name), 1)
self.emit('%s);' % reflow_c_string(asdl_of(name, sum), 2), 2, reflow=False)
self.emit_type("%s_type" % name)
Expand All @@ -1067,35 +1160,33 @@ def visitSum(self, sum, name):
(name, name, len(sum.attributes)), 1)
else:
self.emit("if (!add_attributes(state, state->%s_type, NULL, 0)) return 0;" % name, 1)
self.emit_defaults(name, sum.attributes, 1)
for attribute in sum.attributes:
if attribute.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1) {' %
(name, attribute.name), 1)
self.emit("return 0;", 2)
self.emit("}", 1)
simple = is_simple(sum)
for t in sum.types:
self.visitConstructor(t, name, simple)

def visitConstructor(self, cons, name, simple):
if cons.fields:
fields = cons.name+"_fields"
field_qualifiers = cons.name+"_field_qualifiers"
else:
fields = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,' %
(cons.name, cons.name, name, fields, len(cons.fields)), 1)
fields = field_qualifiers = "NULL"
self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %s, %d,' %
(cons.name, cons.name, name, fields, field_qualifiers, len(cons.fields)), 1)
self.emit('%s);' % reflow_c_string(asdl_of(cons.name, cons), 2), 2, reflow=False)
self.emit("if (!state->%s_type) return 0;" % cons.name, 1)
self.emit_type("%s_type" % cons.name)
self.emit_defaults(cons.name, cons.fields, 1)
if simple:
self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
"state->%s_type, NULL, NULL);" %
(cons.name, cons.name), 1)
self.emit("if (!state->%s_singleton) return 0;" % cons.name, 1)

def emit_defaults(self, name, fields, depth):
for field in fields:
if field.opt:
self.emit('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)' %
(name, field.name), depth)
self.emit("return 0;", depth+1)


class ASTModuleVisitor(PickleVisitor):

Expand Down Expand Up @@ -1397,6 +1488,7 @@ def generate_module_def(f, mod):
state_strings = {
"ast",
"_fields",
"_field_qualifiers",
"__doc__",
"__dict__",
"__module__",
Expand Down
Loading