Skip to content

Support for nonlocal declaration #504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 3, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions mypy/lex.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,29 @@ def __init__(self, string: str, type: int) -> None:
COMMENT_CONTEXT = 2


def lex(string: str, first_line: int = 1) -> List[Token]:
def lex(string: str, first_line: int = 1, pyversion: int = 3) -> List[Token]:
"""Analyze string and return an array of token objects.

The last token is always Eof.
"""
l = Lexer()
l = Lexer(pyversion)
l.lex(string, first_line)
return l.tok


# Reserved words (not including operators)
keywords = set([
keywords_common = set([
'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
'else', 'except', 'finally', 'from', 'for', 'global', 'if', 'import',
'lambda', 'pass', 'raise', 'return', 'try', 'while', 'with',
'yield'])

# Reserved words specific for Python version 2
keywords2 = set(['print'])

# Reserved words specific for Python version 3
keywords3 = set(['nonlocal'])

# Alphabetical operators (reserved words)
alpha_operators = set(['in', 'is', 'not', 'and', 'or'])

Expand Down Expand Up @@ -279,7 +285,7 @@ class Lexer:
# newlines within parentheses/brackets.
open_brackets = Undefined(List[str])

def __init__(self) -> None:
def __init__(self, pyversion: int = 3) -> None:
self.map = [self.unknown_character] * 256
self.tok = []
self.indents = [0]
Expand All @@ -302,6 +308,10 @@ def __init__(self) -> None:
('-+*/<>%&|^~=!,@', self.lex_misc)]:
for c in seq:
self.map[ord(c)] = method
if pyversion == 2:
self.keywords = keywords_common | keywords2
if pyversion == 3:
self.keywords = keywords_common | keywords3

def lex(self, s: str, first_line: int) -> None:
"""Lexically analyze a string, storing the tokens at the tok list."""
Expand Down Expand Up @@ -401,7 +411,7 @@ def lex_name(self) -> None:
Also deal with prefixed string literals such as r'...'.
"""
s = self.match(self.name_exp)
if s in keywords:
if s in self.keywords:
self.add_token(Keyword(s))
elif s in alpha_operators:
self.add_token(Op(s))
Expand Down
9 changes: 9 additions & 0 deletions mypy/noderepr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ def __init__(self, global_tok: Any, names: List[Token],
self.br = br


class NonlocalDeclRepr:
def __init__(self, nonlocal_tok: Any, names: List[Token],
commas: List[Token], br: Any) -> None:
self.nonlocal_tok = nonlocal_tok
self.names = names
self.commas = commas
self.br = br


class ExpressionStmtRepr:
def __init__(self, br: Any) -> None:
self.br = br
Expand Down
12 changes: 12 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,18 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_global_decl(self)


class NonlocalDecl(Node):
"""Declaration nonlocal x, y, ..."""

names = Undefined(List[str])

def __init__(self, names: List[str]) -> None:
self.names = names

def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_nonlocal_decl(self)


class Block(Node):
body = Undefined(List[Node])
# True if we can determine that this block is not executed. For example,
Expand Down
9 changes: 9 additions & 0 deletions mypy/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ def visit_global_decl(self, o):
self.token(r.commas[i])
self.token(r.br)

def visit_nonlocal_decl(self, o):
r = o.repr
self.token(r.nonlocal_tok)
for i in range(len(r.names)):
self.token(r.names[i])
if i < len(r.commas):
self.token(r.commas[i])
self.token(r.br)

def visit_expression_stmt(self, o):
self.node(o.expr)
self.token(o.repr.br)
Expand Down
29 changes: 22 additions & 7 deletions mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr,
FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr,
UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr,
StarExpr, YieldFromStmt, YieldFromExpr
StarExpr, YieldFromStmt, YieldFromExpr, NonlocalDecl
)
from mypy import nodes
from mypy import noderepr
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(self, fnam: str, errors: Errors, pyversion: int,
self.errors.set_file('<input>')

def parse(self, s: str) -> MypyFile:
self.tok = lex.lex(s)
self.tok = lex.lex(s, pyversion=self.pyversion)
self.ind = 0
self.imports = []
self.future_options = []
Expand Down Expand Up @@ -678,6 +678,8 @@ def parse_statement(self) -> Node:
stmt = self.parse_class_def()
elif ts == 'global':
stmt = self.parse_global_decl()
elif ts == 'nonlocal' and self.pyversion >= 3:
stmt = self.parse_nonlocal_decl()
elif ts == 'assert':
stmt = self.parse_assert_stmt()
elif ts == 'yield':
Expand Down Expand Up @@ -840,6 +842,23 @@ def parse_pass_stmt(self) -> PassStmt:

def parse_global_decl(self) -> GlobalDecl:
global_tok = self.expect('global')
name_toks, names, commas = self.parse_identifier_list()
br = self.expect_break()
node = GlobalDecl(names)
self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks,
commas, br))
return node

def parse_nonlocal_decl(self) -> NonlocalDecl:
nonlocal_tok = self.expect('nonlocal')
name_toks, names, commas = self.parse_identifier_list()
br = self.expect_break()
node = NonlocalDecl(names)
self.set_repr(node, noderepr.NonlocalDeclRepr(nonlocal_tok, name_toks,
commas, br))
return node

def parse_identifier_list(self) -> Tuple[List[Token], List[str], List[Token]]:
names = List[str]()
name_toks = List[Token]()
commas = List[Token]()
Expand All @@ -850,11 +869,7 @@ def parse_global_decl(self) -> GlobalDecl:
if self.current_str() != ',':
break
commas.append(self.skip())
br = self.expect_break()
node = GlobalDecl(names)
self.set_repr(node, noderepr.GlobalDeclRepr(global_tok, name_toks,
commas, br))
return node
return name_toks, names, commas

def parse_while_stmt(self) -> WhileStmt:
is_error = False
Expand Down
45 changes: 40 additions & 5 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr,
StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr,
ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases,
YieldFromStmt, YieldFromExpr, NamedTupleExpr
YieldFromStmt, YieldFromExpr, NamedTupleExpr, NonlocalDecl
)
from mypy.visitor import NodeVisitor
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -103,6 +103,8 @@ class SemanticAnalyzer(NodeVisitor):
globals = Undefined(SymbolTable)
# Names declared using "global" (separate set for each scope)
global_decls = Undefined(List[Set[str]])
# Names declated using "nonlocal" (separate set for each scope)
nonlocal_decls = Undefined(List[Set[str]])
# Local names of function scopes; None for non-function scopes.
locals = Undefined(List[SymbolTable])
# Nested block depths of scopes
Expand Down Expand Up @@ -764,7 +766,8 @@ def analyse_lvalue(self, lval: Node, nested: bool = False,
v = cast(Var, lval.node)
assert v.name() in self.globals
elif (self.is_func_scope() and lval.name not in self.locals[-1] and
lval.name not in self.global_decls[-1]):
lval.name not in self.global_decls[-1] and
lval.name not in self.nonlocal_decls[-1]):
# Define new local name.
v = Var(lval.name)
lval.node = v
Expand Down Expand Up @@ -1274,8 +1277,29 @@ def visit_del_stmt(self, s: DelStmt) -> None:
self.fail('Invalid delete target', s)

def visit_global_decl(self, g: GlobalDecl) -> None:
for n in g.names:
self.global_decls[-1].add(n)
for name in g.names:
if name in self.nonlocal_decls[-1]:
self.fail("Name '{}' is nonlocal and global".format(name), g)
self.global_decls[-1].add(name)

def visit_nonlocal_decl(self, d: NonlocalDecl) -> None:
if not self.is_func_scope():
self.fail("nonlocal declaration not allowed at module level", d)
else:
for name in d.names:
for table in reversed(self.locals[:-1]):
if table is not None and name in table:
break
else:
self.fail("No binding for nonlocal '{}' found".format(name), d)

if self.locals[-1] is not None and name in self.locals[-1]:
self.fail("Name '{}' is already defined in local "
"scope before nonlocal declaration".format(name), d)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Message capitalization


if name in self.global_decls[-1]:
self.fail("Name '{}' is nonlocal and global".format(name), d)
self.nonlocal_decls[-1].add(name)

def visit_print_stmt(self, s: PrintStmt) -> None:
for arg in s.args:
Expand Down Expand Up @@ -1530,13 +1554,21 @@ def visit_disjointclass_expr(self, expr: DisjointclassExpr) -> None:

def lookup(self, name: str, ctx: Context) -> SymbolTableNode:
"""Look up an unqualified name in all active namespaces."""
# 1. Name declared using 'global x' takes precedence
# 1a. Name declared using 'global x' takes precedence
if name in self.global_decls[-1]:
if name in self.globals:
return self.globals[name]
else:
self.name_not_defined(name, ctx)
return None
# 1b. Name declared using 'nonlocal x' takes precedence
if name in self.nonlocal_decls[-1]:
for table in reversed(self.locals[:-1]):
if table is not None and name in table:
return table[name]
else:
self.name_not_defined(name, ctx)
return None
# 2. Class attributes (if within class definition)
if self.is_class_scope() and name in self.type.names:
return self.type[name]
Expand Down Expand Up @@ -1607,10 +1639,12 @@ def qualified_name(self, n: str) -> str:
def enter(self) -> None:
self.locals.append(SymbolTable())
self.global_decls.append(set())
self.nonlocal_decls.append(set())

def leave(self) -> None:
self.locals.pop()
self.global_decls.pop()
self.nonlocal_decls.pop()

def is_func_scope(self) -> bool:
return self.locals[-1] is not None
Expand Down Expand Up @@ -1696,6 +1730,7 @@ def analyze(self, file: MypyFile, fnam: str, mod_id: str) -> None:
sem.errors.set_file(fnam)
sem.globals = SymbolTable()
sem.global_decls = [set()]
sem.nonlocal_decls = [set()]
sem.block_depth = [0]

defs = file.defs
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def visit_var(self, o):
def visit_global_decl(self, o):
return self.dump([o.names], o)

def visit_nonlocal_decl(self, o):
return self.dump([o.names], o)

def visit_decorator(self, o):
return self.dump([o.var, o.decorators, o.func], o)

Expand Down
50 changes: 50 additions & 0 deletions mypy/test/data/check-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ class B: pass
-- Type aliases
-- ------------


[case testSimpleTypeAlias]
import typing
foo = int
Expand All @@ -885,3 +886,52 @@ f('x') # E: Argument 1 to "f" has incompatible type "str"; expected "int"
[file m.py]
import typing
foo = int


-- nonlocal and global
-- -------------------


[case testTypeOfGlobalUsed]
import typing
g = A()
def f() -> None:
global g
g = B()

class A(): pass
class B(): pass
[out]
main: In function "f":
main, line 5: Incompatible types in assignment (expression has type "B", variable has type "A")

[case testTypeOfNonlocalUsed]
import typing
def f() -> None:
a = A()
def g() -> None:
nonlocal a
a = B()

class A(): pass
class B(): pass
[out]
main: In function "g":
main, line 6: Incompatible types in assignment (expression has type "B", variable has type "A")

[case testTypeOfOuterMostNonlocalUsed]
import typing
def f() -> None:
a = A()
def g() -> None:
a = B()
def h() -> None:
nonlocal a
a = A()
a = B()

class A(): pass
class B(): pass
[out]
main: In function "h":
main, line 8: Incompatible types in assignment (expression has type "A", variable has type "B")
16 changes: 16 additions & 0 deletions mypy/test/data/parse.test
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,22 @@ MypyFile:1(
x
y))))

[case testNonlocalDecl]
def f():
def g():
nonlocal x, y
[out]
MypyFile:1(
FuncDef:1(
f
Block:1(
FuncDef:2(
g
Block:2(
NonlocalDecl:3(
x
y))))))

[case testRaiseStatement]
raise foo
[out]
Expand Down
Loading