Skip to content

Commit a8cb5e4

Browse files
authored
gh-129598: ast: allow multi stmts for ast single with ';' (#129620)
1 parent 20c5f96 commit a8cb5e4

File tree

3 files changed

+164
-30
lines changed

3 files changed

+164
-30
lines changed

Lib/ast.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,7 @@ def __init__(self):
674674
self._type_ignores = {}
675675
self._indent = 0
676676
self._in_try_star = False
677+
self._in_interactive = False
677678

678679
def interleave(self, inter, f, seq):
679680
"""Call f on each item in seq, calling inter() in between."""
@@ -702,11 +703,20 @@ def maybe_newline(self):
702703
if self._source:
703704
self.write("\n")
704705

705-
def fill(self, text=""):
706+
def maybe_semicolon(self):
707+
"""Adds a "; " delimiter if it isn't the start of generated source"""
708+
if self._source:
709+
self.write("; ")
710+
711+
def fill(self, text="", *, allow_semicolon=True):
706712
"""Indent a piece of text and append it, according to the current
707-
indentation level"""
708-
self.maybe_newline()
709-
self.write(" " * self._indent + text)
713+
indentation level, or only delineate with semicolon if applicable"""
714+
if self._in_interactive and not self._indent and allow_semicolon:
715+
self.maybe_semicolon()
716+
self.write(text)
717+
else:
718+
self.maybe_newline()
719+
self.write(" " * self._indent + text)
710720

711721
def write(self, *text):
712722
"""Add new source parts"""
@@ -812,8 +822,17 @@ def visit_Module(self, node):
812822
ignore.lineno: f"ignore{ignore.tag}"
813823
for ignore in node.type_ignores
814824
}
815-
self._write_docstring_and_traverse_body(node)
816-
self._type_ignores.clear()
825+
try:
826+
self._write_docstring_and_traverse_body(node)
827+
finally:
828+
self._type_ignores.clear()
829+
830+
def visit_Interactive(self, node):
831+
self._in_interactive = True
832+
try:
833+
self._write_docstring_and_traverse_body(node)
834+
finally:
835+
self._in_interactive = False
817836

818837
def visit_FunctionType(self, node):
819838
with self.delimit("(", ")"):
@@ -945,17 +964,17 @@ def visit_Raise(self, node):
945964
self.traverse(node.cause)
946965

947966
def do_visit_try(self, node):
948-
self.fill("try")
967+
self.fill("try", allow_semicolon=False)
949968
with self.block():
950969
self.traverse(node.body)
951970
for ex in node.handlers:
952971
self.traverse(ex)
953972
if node.orelse:
954-
self.fill("else")
973+
self.fill("else", allow_semicolon=False)
955974
with self.block():
956975
self.traverse(node.orelse)
957976
if node.finalbody:
958-
self.fill("finally")
977+
self.fill("finally", allow_semicolon=False)
959978
with self.block():
960979
self.traverse(node.finalbody)
961980

@@ -976,7 +995,7 @@ def visit_TryStar(self, node):
976995
self._in_try_star = prev_in_try_star
977996

978997
def visit_ExceptHandler(self, node):
979-
self.fill("except*" if self._in_try_star else "except")
998+
self.fill("except*" if self._in_try_star else "except", allow_semicolon=False)
980999
if node.type:
9811000
self.write(" ")
9821001
self.traverse(node.type)
@@ -989,9 +1008,9 @@ def visit_ExceptHandler(self, node):
9891008
def visit_ClassDef(self, node):
9901009
self.maybe_newline()
9911010
for deco in node.decorator_list:
992-
self.fill("@")
1011+
self.fill("@", allow_semicolon=False)
9931012
self.traverse(deco)
994-
self.fill("class " + node.name)
1013+
self.fill("class " + node.name, allow_semicolon=False)
9951014
if hasattr(node, "type_params"):
9961015
self._type_params_helper(node.type_params)
9971016
with self.delimit_if("(", ")", condition = node.bases or node.keywords):
@@ -1021,10 +1040,10 @@ def visit_AsyncFunctionDef(self, node):
10211040
def _function_helper(self, node, fill_suffix):
10221041
self.maybe_newline()
10231042
for deco in node.decorator_list:
1024-
self.fill("@")
1043+
self.fill("@", allow_semicolon=False)
10251044
self.traverse(deco)
10261045
def_str = fill_suffix + " " + node.name
1027-
self.fill(def_str)
1046+
self.fill(def_str, allow_semicolon=False)
10281047
if hasattr(node, "type_params"):
10291048
self._type_params_helper(node.type_params)
10301049
with self.delimit("(", ")"):
@@ -1075,54 +1094,54 @@ def visit_AsyncFor(self, node):
10751094
self._for_helper("async for ", node)
10761095

10771096
def _for_helper(self, fill, node):
1078-
self.fill(fill)
1097+
self.fill(fill, allow_semicolon=False)
10791098
self.set_precedence(_Precedence.TUPLE, node.target)
10801099
self.traverse(node.target)
10811100
self.write(" in ")
10821101
self.traverse(node.iter)
10831102
with self.block(extra=self.get_type_comment(node)):
10841103
self.traverse(node.body)
10851104
if node.orelse:
1086-
self.fill("else")
1105+
self.fill("else", allow_semicolon=False)
10871106
with self.block():
10881107
self.traverse(node.orelse)
10891108

10901109
def visit_If(self, node):
1091-
self.fill("if ")
1110+
self.fill("if ", allow_semicolon=False)
10921111
self.traverse(node.test)
10931112
with self.block():
10941113
self.traverse(node.body)
10951114
# collapse nested ifs into equivalent elifs.
10961115
while node.orelse and len(node.orelse) == 1 and isinstance(node.orelse[0], If):
10971116
node = node.orelse[0]
1098-
self.fill("elif ")
1117+
self.fill("elif ", allow_semicolon=False)
10991118
self.traverse(node.test)
11001119
with self.block():
11011120
self.traverse(node.body)
11021121
# final else
11031122
if node.orelse:
1104-
self.fill("else")
1123+
self.fill("else", allow_semicolon=False)
11051124
with self.block():
11061125
self.traverse(node.orelse)
11071126

11081127
def visit_While(self, node):
1109-
self.fill("while ")
1128+
self.fill("while ", allow_semicolon=False)
11101129
self.traverse(node.test)
11111130
with self.block():
11121131
self.traverse(node.body)
11131132
if node.orelse:
1114-
self.fill("else")
1133+
self.fill("else", allow_semicolon=False)
11151134
with self.block():
11161135
self.traverse(node.orelse)
11171136

11181137
def visit_With(self, node):
1119-
self.fill("with ")
1138+
self.fill("with ", allow_semicolon=False)
11201139
self.interleave(lambda: self.write(", "), self.traverse, node.items)
11211140
with self.block(extra=self.get_type_comment(node)):
11221141
self.traverse(node.body)
11231142

11241143
def visit_AsyncWith(self, node):
1125-
self.fill("async with ")
1144+
self.fill("async with ", allow_semicolon=False)
11261145
self.interleave(lambda: self.write(", "), self.traverse, node.items)
11271146
with self.block(extra=self.get_type_comment(node)):
11281147
self.traverse(node.body)
@@ -1264,7 +1283,7 @@ def visit_Name(self, node):
12641283
self.write(node.id)
12651284

12661285
def _write_docstring(self, node):
1267-
self.fill()
1286+
self.fill(allow_semicolon=False)
12681287
if node.kind == "u":
12691288
self.write("u")
12701289
self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
@@ -1558,7 +1577,7 @@ def visit_Slice(self, node):
15581577
self.traverse(node.step)
15591578

15601579
def visit_Match(self, node):
1561-
self.fill("match ")
1580+
self.fill("match ", allow_semicolon=False)
15621581
self.traverse(node.subject)
15631582
with self.block():
15641583
for case in node.cases:
@@ -1652,7 +1671,7 @@ def visit_withitem(self, node):
16521671
self.traverse(node.optional_vars)
16531672

16541673
def visit_match_case(self, node):
1655-
self.fill("case ")
1674+
self.fill("case ", allow_semicolon=False)
16561675
self.traverse(node.pattern)
16571676
if node.guard:
16581677
self.write(" if ")

Lib/test/test_unparse.py

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,13 @@ def check_invalid(self, node, raises=ValueError):
142142
with self.subTest(node=node):
143143
self.assertRaises(raises, ast.unparse, node)
144144

145-
def get_source(self, code1, code2=None):
145+
def get_source(self, code1, code2=None, **kwargs):
146146
code2 = code2 or code1
147-
code1 = ast.unparse(ast.parse(code1))
147+
code1 = ast.unparse(ast.parse(code1, **kwargs))
148148
return code1, code2
149149

150-
def check_src_roundtrip(self, code1, code2=None):
151-
code1, code2 = self.get_source(code1, code2)
150+
def check_src_roundtrip(self, code1, code2=None, **kwargs):
151+
code1, code2 = self.get_source(code1, code2, **kwargs)
152152
with self.subTest(code1=code1, code2=code2):
153153
self.assertEqual(code2, code1)
154154

@@ -469,6 +469,120 @@ def test_type_ignore(self):
469469
):
470470
self.check_ast_roundtrip(statement, type_comments=True)
471471

472+
def test_unparse_interactive_semicolons(self):
473+
# gh-129598: Fix ast.unparse() when ast.Interactive contains multiple statements
474+
self.check_src_roundtrip("i = 1; 'expr'; raise Exception", mode='single')
475+
self.check_src_roundtrip("i: int = 1; j: float = 0; k += l", mode='single')
476+
combinable = (
477+
"'expr'",
478+
"(i := 1)",
479+
"import foo",
480+
"from foo import bar",
481+
"i = 1",
482+
"i += 1",
483+
"i: int = 1",
484+
"return i",
485+
"pass",
486+
"break",
487+
"continue",
488+
"del i",
489+
"assert i",
490+
"global i",
491+
"nonlocal j",
492+
"await i",
493+
"yield i",
494+
"yield from i",
495+
"raise i",
496+
"type t[T] = ...",
497+
"i",
498+
)
499+
for a in combinable:
500+
for b in combinable:
501+
self.check_src_roundtrip(f"{a}; {b}", mode='single')
502+
503+
def test_unparse_interactive_integrity_1(self):
504+
# rest of unparse_interactive_integrity tests just make sure mode='single' parse and unparse didn't break
505+
self.check_src_roundtrip(
506+
"if i:\n 'expr'\nelse:\n raise Exception",
507+
"if i:\n 'expr'\nelse:\n raise Exception",
508+
mode='single'
509+
)
510+
self.check_src_roundtrip(
511+
"@decorator1\n@decorator2\ndef func():\n 'docstring'\n i = 1; 'expr'; raise Exception",
512+
'''@decorator1\n@decorator2\ndef func():\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
513+
mode='single'
514+
)
515+
self.check_src_roundtrip(
516+
"@decorator1\n@decorator2\nclass cls:\n 'docstring'\n i = 1; 'expr'; raise Exception",
517+
'''@decorator1\n@decorator2\nclass cls:\n """docstring"""\n i = 1\n 'expr'\n raise Exception''',
518+
mode='single'
519+
)
520+
521+
def test_unparse_interactive_integrity_2(self):
522+
for statement in (
523+
"def x():\n pass",
524+
"def x(y):\n pass",
525+
"async def x():\n pass",
526+
"async def x(y):\n pass",
527+
"for x in y:\n pass",
528+
"async for x in y:\n pass",
529+
"with x():\n pass",
530+
"async with x():\n pass",
531+
"def f():\n pass",
532+
"def f(a):\n pass",
533+
"def f(b=2):\n pass",
534+
"def f(a, b):\n pass",
535+
"def f(a, b=2):\n pass",
536+
"def f(a=5, b=2):\n pass",
537+
"def f(*, a=1, b=2):\n pass",
538+
"def f(*, a=1, b):\n pass",
539+
"def f(*, a, b=2):\n pass",
540+
"def f(a, b=None, *, c, **kwds):\n pass",
541+
"def f(a=2, *args, c=5, d, **kwds):\n pass",
542+
"def f(*args, **kwargs):\n pass",
543+
"class cls:\n\n def f(self):\n pass",
544+
"class cls:\n\n def f(self, a):\n pass",
545+
"class cls:\n\n def f(self, b=2):\n pass",
546+
"class cls:\n\n def f(self, a, b):\n pass",
547+
"class cls:\n\n def f(self, a, b=2):\n pass",
548+
"class cls:\n\n def f(self, a=5, b=2):\n pass",
549+
"class cls:\n\n def f(self, *, a=1, b=2):\n pass",
550+
"class cls:\n\n def f(self, *, a=1, b):\n pass",
551+
"class cls:\n\n def f(self, *, a, b=2):\n pass",
552+
"class cls:\n\n def f(self, a, b=None, *, c, **kwds):\n pass",
553+
"class cls:\n\n def f(self, a=2, *args, c=5, d, **kwds):\n pass",
554+
"class cls:\n\n def f(self, *args, **kwargs):\n pass",
555+
):
556+
self.check_src_roundtrip(statement, mode='single')
557+
558+
def test_unparse_interactive_integrity_3(self):
559+
for statement in (
560+
"def x():",
561+
"def x(y):",
562+
"async def x():",
563+
"async def x(y):",
564+
"for x in y:",
565+
"async for x in y:",
566+
"with x():",
567+
"async with x():",
568+
"def f():",
569+
"def f(a):",
570+
"def f(b=2):",
571+
"def f(a, b):",
572+
"def f(a, b=2):",
573+
"def f(a=5, b=2):",
574+
"def f(*, a=1, b=2):",
575+
"def f(*, a=1, b):",
576+
"def f(*, a, b=2):",
577+
"def f(a, b=None, *, c, **kwds):",
578+
"def f(a=2, *args, c=5, d, **kwds):",
579+
"def f(*args, **kwargs):",
580+
):
581+
src = statement + '\n i=1;j=2'
582+
out = statement + '\n i = 1\n j = 2'
583+
584+
self.check_src_roundtrip(src, out, mode='single')
585+
472586

473587
class CosmeticTestCase(ASTTestCase):
474588
"""Test if there are cosmetic issues caused by unnecessary additions"""
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix :func:`ast.unparse` when :class:`ast.Interactive` contains multiple statements.

0 commit comments

Comments
 (0)