Skip to content

Commit ed32d9a

Browse files
committed
Merge pull request #504 from spkersten/nonlocal2
Support for nonlocal declaration
2 parents 29f1009 + 91e9f40 commit ed32d9a

13 files changed

+327
-18
lines changed

mypy/lex.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,29 @@ def __init__(self, string: str, type: int) -> None:
148148
COMMENT_CONTEXT = 2
149149

150150

151-
def lex(string: str, first_line: int = 1) -> List[Token]:
151+
def lex(string: str, first_line: int = 1, pyversion: int = 3) -> List[Token]:
152152
"""Analyze string and return an array of token objects.
153153
154154
The last token is always Eof.
155155
"""
156-
l = Lexer()
156+
l = Lexer(pyversion)
157157
l.lex(string, first_line)
158158
return l.tok
159159

160160

161161
# Reserved words (not including operators)
162-
keywords = set([
162+
keywords_common = set([
163163
'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
164164
'else', 'except', 'finally', 'from', 'for', 'global', 'if', 'import',
165165
'lambda', 'pass', 'raise', 'return', 'try', 'while', 'with',
166166
'yield'])
167167

168+
# Reserved words specific for Python version 2
169+
keywords2 = set(['print'])
170+
171+
# Reserved words specific for Python version 3
172+
keywords3 = set(['nonlocal'])
173+
168174
# Alphabetical operators (reserved words)
169175
alpha_operators = set(['in', 'is', 'not', 'and', 'or'])
170176

@@ -279,7 +285,7 @@ class Lexer:
279285
# newlines within parentheses/brackets.
280286
open_brackets = Undefined(List[str])
281287

282-
def __init__(self) -> None:
288+
def __init__(self, pyversion: int = 3) -> None:
283289
self.map = [self.unknown_character] * 256
284290
self.tok = []
285291
self.indents = [0]
@@ -302,6 +308,10 @@ def __init__(self) -> None:
302308
('-+*/<>%&|^~=!,@', self.lex_misc)]:
303309
for c in seq:
304310
self.map[ord(c)] = method
311+
if pyversion == 2:
312+
self.keywords = keywords_common | keywords2
313+
if pyversion == 3:
314+
self.keywords = keywords_common | keywords3
305315

306316
def lex(self, s: str, first_line: int) -> None:
307317
"""Lexically analyze a string, storing the tokens at the tok list."""
@@ -401,7 +411,7 @@ def lex_name(self) -> None:
401411
Also deal with prefixed string literals such as r'...'.
402412
"""
403413
s = self.match(self.name_exp)
404-
if s in keywords:
414+
if s in self.keywords:
405415
self.add_token(Keyword(s))
406416
elif s in alpha_operators:
407417
self.add_token(Op(s))

mypy/noderepr.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ def __init__(self, global_tok: Any, names: List[Token],
120120
self.br = br
121121

122122

123+
class NonlocalDeclRepr:
124+
def __init__(self, nonlocal_tok: Any, names: List[Token],
125+
commas: List[Token], br: Any) -> None:
126+
self.nonlocal_tok = nonlocal_tok
127+
self.names = names
128+
self.commas = commas
129+
self.br = br
130+
131+
123132
class ExpressionStmtRepr:
124133
def __init__(self, br: Any) -> None:
125134
self.br = br

mypy/nodes.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,18 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
483483
return visitor.visit_global_decl(self)
484484

485485

486+
class NonlocalDecl(Node):
487+
"""Declaration nonlocal x, y, ..."""
488+
489+
names = Undefined(List[str])
490+
491+
def __init__(self, names: List[str]) -> None:
492+
self.names = names
493+
494+
def accept(self, visitor: NodeVisitor[T]) -> T:
495+
return visitor.visit_nonlocal_decl(self)
496+
497+
486498
class Block(Node):
487499
body = Undefined(List[Node])
488500
# True if we can determine that this block is not executed. For example,

mypy/output.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,15 @@ def visit_global_decl(self, o):
196196
self.token(r.commas[i])
197197
self.token(r.br)
198198

199+
def visit_nonlocal_decl(self, o):
200+
r = o.repr
201+
self.token(r.nonlocal_tok)
202+
for i in range(len(r.names)):
203+
self.token(r.names[i])
204+
if i < len(r.commas):
205+
self.token(r.commas[i])
206+
self.token(r.br)
207+
199208
def visit_expression_stmt(self, o):
200209
self.node(o.expr)
201210
self.token(o.repr.br)

mypy/parse.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr,
2525
FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr,
2626
UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr,
27-
StarExpr, YieldFromStmt, YieldFromExpr
27+
StarExpr, YieldFromStmt, YieldFromExpr, NonlocalDecl
2828
)
2929
from mypy import nodes
3030
from mypy import noderepr
@@ -108,7 +108,7 @@ def __init__(self, fnam: str, errors: Errors, pyversion: int,
108108
self.errors.set_file('<input>')
109109

110110
def parse(self, s: str) -> MypyFile:
111-
self.tok = lex.lex(s)
111+
self.tok = lex.lex(s, pyversion=self.pyversion)
112112
self.ind = 0
113113
self.imports = []
114114
self.future_options = []
@@ -678,6 +678,8 @@ def parse_statement(self) -> Node:
678678
stmt = self.parse_class_def()
679679
elif ts == 'global':
680680
stmt = self.parse_global_decl()
681+
elif ts == 'nonlocal' and self.pyversion >= 3:
682+
stmt = self.parse_nonlocal_decl()
681683
elif ts == 'assert':
682684
stmt = self.parse_assert_stmt()
683685
elif ts == 'yield':
@@ -840,6 +842,23 @@ def parse_pass_stmt(self) -> PassStmt:
840842

841843
def parse_global_decl(self) -> GlobalDecl:
842844
global_tok = self.expect('global')
845+
name_toks, names, commas = self.parse_identifier_list()
846+
br = self.expect_break()
847+
node = GlobalDecl(names)
848+
self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks,
849+
commas, br))
850+
return node
851+
852+
def parse_nonlocal_decl(self) -> NonlocalDecl:
853+
nonlocal_tok = self.expect('nonlocal')
854+
name_toks, names, commas = self.parse_identifier_list()
855+
br = self.expect_break()
856+
node = NonlocalDecl(names)
857+
self.set_repr(node, noderepr.NonlocalDeclRepr(nonlocal_tok, name_toks,
858+
commas, br))
859+
return node
860+
861+
def parse_identifier_list(self) -> Tuple[List[Token], List[str], List[Token]]:
843862
names = List[str]()
844863
name_toks = List[Token]()
845864
commas = List[Token]()
@@ -850,11 +869,7 @@ def parse_global_decl(self) -> GlobalDecl:
850869
if self.current_str() != ',':
851870
break
852871
commas.append(self.skip())
853-
br = self.expect_break()
854-
node = GlobalDecl(names)
855-
self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks,
856-
commas, br))
857-
return node
872+
return name_toks, names, commas
858873

859874
def parse_while_stmt(self) -> WhileStmt:
860875
is_error = False

mypy/semanal.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr,
5858
StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr,
5959
ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases,
60-
YieldFromStmt, YieldFromExpr, NamedTupleExpr
60+
YieldFromStmt, YieldFromExpr, NamedTupleExpr, NonlocalDecl
6161
)
6262
from mypy.visitor import NodeVisitor
6363
from mypy.traverser import TraverserVisitor
@@ -103,6 +103,8 @@ class SemanticAnalyzer(NodeVisitor):
103103
globals = Undefined(SymbolTable)
104104
# Names declared using "global" (separate set for each scope)
105105
global_decls = Undefined(List[Set[str]])
106+
# Names declated using "nonlocal" (separate set for each scope)
107+
nonlocal_decls = Undefined(List[Set[str]])
106108
# Local names of function scopes; None for non-function scopes.
107109
locals = Undefined(List[SymbolTable])
108110
# Nested block depths of scopes
@@ -764,7 +766,8 @@ def analyse_lvalue(self, lval: Node, nested: bool = False,
764766
v = cast(Var, lval.node)
765767
assert v.name() in self.globals
766768
elif (self.is_func_scope() and lval.name not in self.locals[-1] and
767-
lval.name not in self.global_decls[-1]):
769+
lval.name not in self.global_decls[-1] and
770+
lval.name not in self.nonlocal_decls[-1]):
768771
# Define new local name.
769772
v = Var(lval.name)
770773
lval.node = v
@@ -1274,8 +1277,29 @@ def visit_del_stmt(self, s: DelStmt) -> None:
12741277
self.fail('Invalid delete target', s)
12751278

12761279
def visit_global_decl(self, g: GlobalDecl) -> None:
1277-
for n in g.names:
1278-
self.global_decls[-1].add(n)
1280+
for name in g.names:
1281+
if name in self.nonlocal_decls[-1]:
1282+
self.fail("Name '{}' is nonlocal and global".format(name), g)
1283+
self.global_decls[-1].add(name)
1284+
1285+
def visit_nonlocal_decl(self, d: NonlocalDecl) -> None:
1286+
if not self.is_func_scope():
1287+
self.fail("nonlocal declaration not allowed at module level", d)
1288+
else:
1289+
for name in d.names:
1290+
for table in reversed(self.locals[:-1]):
1291+
if table is not None and name in table:
1292+
break
1293+
else:
1294+
self.fail("No binding for nonlocal '{}' found".format(name), d)
1295+
1296+
if self.locals[-1] is not None and name in self.locals[-1]:
1297+
self.fail("Name '{}' is already defined in local "
1298+
"scope before nonlocal declaration".format(name), d)
1299+
1300+
if name in self.global_decls[-1]:
1301+
self.fail("Name '{}' is nonlocal and global".format(name), d)
1302+
self.nonlocal_decls[-1].add(name)
12791303

12801304
def visit_print_stmt(self, s: PrintStmt) -> None:
12811305
for arg in s.args:
@@ -1530,13 +1554,21 @@ def visit_disjointclass_expr(self, expr: DisjointclassExpr) -> None:
15301554

15311555
def lookup(self, name: str, ctx: Context) -> SymbolTableNode:
15321556
"""Look up an unqualified name in all active namespaces."""
1533-
# 1. Name declared using 'global x' takes precedence
1557+
# 1a. Name declared using 'global x' takes precedence
15341558
if name in self.global_decls[-1]:
15351559
if name in self.globals:
15361560
return self.globals[name]
15371561
else:
15381562
self.name_not_defined(name, ctx)
15391563
return None
1564+
# 1b. Name declared using 'nonlocal x' takes precedence
1565+
if name in self.nonlocal_decls[-1]:
1566+
for table in reversed(self.locals[:-1]):
1567+
if table is not None and name in table:
1568+
return table[name]
1569+
else:
1570+
self.name_not_defined(name, ctx)
1571+
return None
15401572
# 2. Class attributes (if within class definition)
15411573
if self.is_class_scope() and name in self.type.names:
15421574
return self.type[name]
@@ -1607,10 +1639,12 @@ def qualified_name(self, n: str) -> str:
16071639
def enter(self) -> None:
16081640
self.locals.append(SymbolTable())
16091641
self.global_decls.append(set())
1642+
self.nonlocal_decls.append(set())
16101643

16111644
def leave(self) -> None:
16121645
self.locals.pop()
16131646
self.global_decls.pop()
1647+
self.nonlocal_decls.pop()
16141648

16151649
def is_func_scope(self) -> bool:
16161650
return self.locals[-1] is not None
@@ -1696,6 +1730,7 @@ def analyze(self, file: MypyFile, fnam: str, mod_id: str) -> None:
16961730
sem.errors.set_file(fnam)
16971731
sem.globals = SymbolTable()
16981732
sem.global_decls = [set()]
1733+
sem.nonlocal_decls = [set()]
16991734
sem.block_depth = [0]
17001735

17011736
defs = file.defs

mypy/strconv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def visit_var(self, o):
162162
def visit_global_decl(self, o):
163163
return self.dump([o.names], o)
164164

165+
def visit_nonlocal_decl(self, o):
166+
return self.dump([o.names], o)
167+
165168
def visit_decorator(self, o):
166169
return self.dump([o.var, o.decorators, o.func], o)
167170

mypy/test/data/check-statements.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ class B: pass
859859
-- Type aliases
860860
-- ------------
861861

862+
862863
[case testSimpleTypeAlias]
863864
import typing
864865
foo = int
@@ -885,3 +886,52 @@ f('x') # E: Argument 1 to "f" has incompatible type "str"; expected "int"
885886
[file m.py]
886887
import typing
887888
foo = int
889+
890+
891+
-- nonlocal and global
892+
-- -------------------
893+
894+
895+
[case testTypeOfGlobalUsed]
896+
import typing
897+
g = A()
898+
def f() -> None:
899+
global g
900+
g = B()
901+
902+
class A(): pass
903+
class B(): pass
904+
[out]
905+
main: In function "f":
906+
main, line 5: Incompatible types in assignment (expression has type "B", variable has type "A")
907+
908+
[case testTypeOfNonlocalUsed]
909+
import typing
910+
def f() -> None:
911+
a = A()
912+
def g() -> None:
913+
nonlocal a
914+
a = B()
915+
916+
class A(): pass
917+
class B(): pass
918+
[out]
919+
main: In function "g":
920+
main, line 6: Incompatible types in assignment (expression has type "B", variable has type "A")
921+
922+
[case testTypeOfOuterMostNonlocalUsed]
923+
import typing
924+
def f() -> None:
925+
a = A()
926+
def g() -> None:
927+
a = B()
928+
def h() -> None:
929+
nonlocal a
930+
a = A()
931+
a = B()
932+
933+
class A(): pass
934+
class B(): pass
935+
[out]
936+
main: In function "h":
937+
main, line 8: Incompatible types in assignment (expression has type "A", variable has type "B")

mypy/test/data/parse.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,22 @@ MypyFile:1(
880880
x
881881
y))))
882882

883+
[case testNonlocalDecl]
884+
def f():
885+
def g():
886+
nonlocal x, y
887+
[out]
888+
MypyFile:1(
889+
FuncDef:1(
890+
f
891+
Block:1(
892+
FuncDef:2(
893+
g
894+
Block:2(
895+
NonlocalDecl:3(
896+
x
897+
y))))))
898+
883899
[case testRaiseStatement]
884900
raise foo
885901
[out]

0 commit comments

Comments
 (0)