Skip to content

Commit c610a31

Browse files
Support PEP 572 (#6899)
1 parent 5bb6796 commit c610a31

14 files changed

+218
-24
lines changed

mypy/checker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
2525
Import, ImportFrom, ImportAll, ImportBase, TypeAlias,
2626
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
27-
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr,
27+
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr,
2828
is_final_node,
2929
)
3030
from mypy import nodes
@@ -2225,7 +2225,8 @@ def enter_final_context(self, is_final_def: bool) -> Iterator[None]:
22252225
finally:
22262226
self._is_final_def = old_ctx
22272227

2228-
def check_final(self, s: Union[AssignmentStmt, OperatorAssignmentStmt]) -> None:
2228+
def check_final(self,
2229+
s: Union[AssignmentStmt, OperatorAssignmentStmt, AssignmentExpr]) -> None:
22292230
"""Check if this assignment does not assign to a final attribute.
22302231
22312232
This function performs the check only for name assignments at module
@@ -2234,6 +2235,8 @@ def check_final(self, s: Union[AssignmentStmt, OperatorAssignmentStmt]) -> None:
22342235
"""
22352236
if isinstance(s, AssignmentStmt):
22362237
lvs = self.flatten_lvalues(s.lvalues)
2238+
elif isinstance(s, AssignmentExpr):
2239+
lvs = [s.target]
22372240
else:
22382241
lvs = [s.lvalue]
22392242
is_final_decl = s.is_final_def if isinstance(s, AssignmentStmt) else False

mypy/checkexpr.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr,
2727
TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression,
2828
ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator,
29-
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
29+
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension, AssignmentExpr,
3030
DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr,
3131
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
3232
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
@@ -2544,6 +2544,12 @@ def check_list_multiply(self, e: OpExpr) -> Type:
25442544
e.method_type = method_type
25452545
return result
25462546

2547+
def visit_assignment_expr(self, e: AssignmentExpr) -> Type:
2548+
value = self.accept(e.value)
2549+
self.chk.check_assignment(e.target, e.value)
2550+
self.chk.check_final(e)
2551+
return value
2552+
25472553
def visit_unary_expr(self, e: UnaryExpr) -> Type:
25482554
"""Type check an unary operation ('not', '-', '+' or '~')."""
25492555
operand_type = self.accept(e.expr)

mypy/fastparse.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr,
2121
DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr,
2222
FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr,
23-
UnaryExpr, LambdaExpr, ComparisonExpr,
23+
UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr,
2424
StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension,
2525
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
2626
AwaitExpr, TempNode, Expression, Statement,
@@ -757,10 +757,6 @@ def visit_AugAssign(self, n: ast3.AugAssign) -> OperatorAssignmentStmt:
757757
self.visit(n.value))
758758
return self.set_line(s, n)
759759

760-
def visit_NamedExpr(self, n: NamedExpr) -> None:
761-
self.fail("assignment expressions are not yet supported", n.lineno, n.col_offset)
762-
return None
763-
764760
# For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment)
765761
def visit_For(self, n: ast3.For) -> ForStmt:
766762
target_type = self.translate_type_comment(n, n.type_comment)
@@ -902,6 +898,10 @@ def visit_Continue(self, n: ast3.Continue) -> ContinueStmt:
902898

903899
# --- expr ---
904900

901+
def visit_NamedExpr(self, n: NamedExpr) -> AssignmentExpr:
902+
s = AssignmentExpr(self.visit(n.target), self.visit(n.value))
903+
return self.set_line(s, n)
904+
905905
# BoolOp(boolop op, expr* values)
906906
def visit_BoolOp(self, n: ast3.BoolOp) -> OpExpr:
907907
# mypy translates (1 and 2 and 3) as (1 and (2 and 3))

mypy/literals.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr,
99
TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension,
1010
GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr,
11-
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode,
11+
TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr,
1212
)
1313
from mypy.visitor import ExpressionVisitor
1414

@@ -156,6 +156,9 @@ def visit_index_expr(self, e: IndexExpr) -> Optional[Key]:
156156
return ('Index', literal_hash(e.base), literal_hash(e.index))
157157
return None
158158

159+
def visit_assignment_expr(self, e: AssignmentExpr) -> None:
160+
return None
161+
159162
def visit_call_expr(self, e: CallExpr) -> None:
160163
return None
161164

mypy/nodes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,17 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
15931593
return visitor.visit_unary_expr(self)
15941594

15951595

1596+
class AssignmentExpr(Expression):
1597+
"""Assignment expressions in Python 3.8+, like "a := 2"."""
1598+
def __init__(self, target: Expression, value: Expression) -> None:
1599+
super().__init__()
1600+
self.target = target
1601+
self.value = value
1602+
1603+
def accept(self, visitor: ExpressionVisitor[T]) -> T:
1604+
return visitor.visit_assignment_expr(self)
1605+
1606+
15961607
# Map from binary operator id to related method name (in Python 3).
15971608
op_methods = {
15981609
'+': '__add__',

mypy/semanal.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
PlaceholderNode, COVARIANT, CONTRAVARIANT, INVARIANT,
7575
nongen_builtins, get_member_expr_fullname, REVEAL_TYPE,
7676
REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_target_versions,
77-
EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement
77+
EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr,
7878
)
7979
from mypy.tvar_scope import TypeVarScope
8080
from mypy.typevars import fill_typevars
@@ -177,6 +177,8 @@ class SemanticAnalyzer(NodeVisitor[None],
177177
nonlocal_decls = None # type: List[Set[str]]
178178
# Local names of function scopes; None for non-function scopes.
179179
locals = None # type: List[Optional[SymbolTable]]
180+
# Whether each scope is a comprehension scope.
181+
is_comprehension_stack = None # type: List[bool]
180182
# Nested block depths of scopes
181183
block_depth = None # type: List[int]
182184
# TypeInfo of directly enclosing class (or None)
@@ -242,6 +244,7 @@ def __init__(self,
242244
errors: Report analysis errors using this instance
243245
"""
244246
self.locals = [None]
247+
self.is_comprehension_stack = [False]
245248
# Saved namespaces from previous iteration. Every top-level function/method body is
246249
# analyzed in several iterations until all names are resolved. We need to save
247250
# the local namespaces for the top level function and all nested functions between
@@ -519,6 +522,12 @@ def file_context(self,
519522

520523
def visit_func_def(self, defn: FuncDef) -> None:
521524
self.statement = defn
525+
526+
# Visit default values because they may contain assignment expressions.
527+
for arg in defn.arguments:
528+
if arg.initializer:
529+
arg.initializer.accept(self)
530+
522531
defn.is_conditional = self.block_depth[-1] > 0
523532

524533
# Set full names even for those definitions that aren't added
@@ -1148,13 +1157,15 @@ def enter_class(self, info: TypeInfo) -> None:
11481157
# Remember previous active class
11491158
self.type_stack.append(self.type)
11501159
self.locals.append(None) # Add class scope
1160+
self.is_comprehension_stack.append(False)
11511161
self.block_depth.append(-1) # The class body increments this to 0
11521162
self.type = info
11531163

11541164
def leave_class(self) -> None:
11551165
""" Restore analyzer state. """
11561166
self.block_depth.pop()
11571167
self.locals.pop()
1168+
self.is_comprehension_stack.pop()
11581169
self.type = self.type_stack.pop()
11591170

11601171
def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None:
@@ -1858,6 +1869,10 @@ def visit_import_all(self, i: ImportAll) -> None:
18581869
# Assignment
18591870
#
18601871

1872+
def visit_assignment_expr(self, s: AssignmentExpr) -> None:
1873+
s.value.accept(self)
1874+
self.analyze_lvalue(s.target, escape_comprehensions=True)
1875+
18611876
def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
18621877
self.statement = s
18631878

@@ -2493,16 +2508,22 @@ def analyze_lvalue(self,
24932508
lval: Lvalue,
24942509
nested: bool = False,
24952510
explicit_type: bool = False,
2496-
is_final: bool = False) -> None:
2511+
is_final: bool = False,
2512+
escape_comprehensions: bool = False) -> None:
24972513
"""Analyze an lvalue or assignment target.
24982514
24992515
Args:
25002516
lval: The target lvalue
25012517
nested: If true, the lvalue is within a tuple or list lvalue expression
25022518
explicit_type: Assignment has type annotation
2519+
escape_comprehensions: If we are inside a comprehension, set the variable
2520+
in the enclosing scope instead. This implements
2521+
https://www.python.org/dev/peps/pep-0572/#scope-of-the-target
25032522
"""
2523+
if escape_comprehensions:
2524+
assert isinstance(lval, NameExpr), "assignment expression target must be NameExpr"
25042525
if isinstance(lval, NameExpr):
2505-
self.analyze_name_lvalue(lval, explicit_type, is_final)
2526+
self.analyze_name_lvalue(lval, explicit_type, is_final, escape_comprehensions)
25062527
elif isinstance(lval, MemberExpr):
25072528
self.analyze_member_lvalue(lval, explicit_type, is_final)
25082529
if explicit_type and not self.is_self_member_ref(lval):
@@ -2528,7 +2549,8 @@ def analyze_lvalue(self,
25282549
def analyze_name_lvalue(self,
25292550
lvalue: NameExpr,
25302551
explicit_type: bool,
2531-
is_final: bool) -> None:
2552+
is_final: bool,
2553+
escape_comprehensions: bool) -> None:
25322554
"""Analyze an lvalue that targets a name expression.
25332555
25342556
Arguments are similar to "analyze_lvalue".
@@ -2552,7 +2574,7 @@ def analyze_name_lvalue(self,
25522574
if (not existing or isinstance(existing.node, PlaceholderNode)) and not outer:
25532575
# Define new variable.
25542576
var = self.make_name_lvalue_var(lvalue, kind, not explicit_type)
2555-
added = self.add_symbol(name, var, lvalue)
2577+
added = self.add_symbol(name, var, lvalue, escape_comprehensions=escape_comprehensions)
25562578
# Only bind expression if we successfully added name to symbol table.
25572579
if added:
25582580
lvalue.is_new_def = True
@@ -4082,7 +4104,8 @@ def add_symbol(self,
40824104
context: Context,
40834105
module_public: bool = True,
40844106
module_hidden: bool = False,
4085-
can_defer: bool = True) -> bool:
4107+
can_defer: bool = True,
4108+
escape_comprehensions: bool = False) -> bool:
40864109
"""Add symbol to the currently active symbol table.
40874110
40884111
Generally additions to symbol table should go through this method or
@@ -4104,7 +4127,7 @@ def add_symbol(self,
41044127
node,
41054128
module_public=module_public,
41064129
module_hidden=module_hidden)
4107-
return self.add_symbol_table_node(name, symbol, context, can_defer)
4130+
return self.add_symbol_table_node(name, symbol, context, can_defer, escape_comprehensions)
41084131

41094132
def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None:
41104133
"""Same as above, but skipping the local namespace.
@@ -4132,7 +4155,8 @@ def add_symbol_table_node(self,
41324155
name: str,
41334156
symbol: SymbolTableNode,
41344157
context: Optional[Context] = None,
4135-
can_defer: bool = True) -> bool:
4158+
can_defer: bool = True,
4159+
escape_comprehensions: bool = False) -> bool:
41364160
"""Add symbol table node to the currently active symbol table.
41374161
41384162
Return True if we actually added the symbol, or False if we refused
@@ -4151,7 +4175,7 @@ def add_symbol_table_node(self,
41514175
can_defer: if True, defer current target if adding a placeholder
41524176
context: error context (see above about None value)
41534177
"""
4154-
names = self.current_symbol_table()
4178+
names = self.current_symbol_table(escape_comprehensions=escape_comprehensions)
41554179
existing = names.get(name)
41564180
if isinstance(symbol.node, PlaceholderNode) and can_defer:
41574181
self.defer(context)
@@ -4379,13 +4403,16 @@ def enter(self, function: Union[FuncItem, GeneratorExpr, DictionaryComprehension
43794403
"""Enter a function, generator or comprehension scope."""
43804404
names = self.saved_locals.setdefault(function, SymbolTable())
43814405
self.locals.append(names)
4406+
is_comprehension = isinstance(function, (GeneratorExpr, DictionaryComprehension))
4407+
self.is_comprehension_stack.append(is_comprehension)
43824408
self.global_decls.append(set())
43834409
self.nonlocal_decls.append(set())
43844410
# -1 since entering block will increment this to 0.
43854411
self.block_depth.append(-1)
43864412

43874413
def leave(self) -> None:
43884414
self.locals.pop()
4415+
self.is_comprehension_stack.pop()
43894416
self.global_decls.pop()
43904417
self.nonlocal_decls.pop()
43914418
self.block_depth.pop()
@@ -4412,10 +4439,19 @@ def current_symbol_kind(self) -> int:
44124439
kind = GDEF
44134440
return kind
44144441

4415-
def current_symbol_table(self) -> SymbolTable:
4442+
def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTable:
44164443
if self.is_func_scope():
44174444
assert self.locals[-1] is not None
4418-
names = self.locals[-1]
4445+
if escape_comprehensions:
4446+
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
4447+
if not is_comprehension:
4448+
names = self.locals[-1 - i]
4449+
break
4450+
else:
4451+
assert False, "Should have at least one non-comprehension scope"
4452+
else:
4453+
names = self.locals[-1]
4454+
assert names is not None
44194455
elif self.type is not None:
44204456
names = self.type.names
44214457
else:

mypy/semanal_pass1.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def visit_import(self, node: Import) -> None:
9090

9191
def visit_if_stmt(self, s: IfStmt) -> None:
9292
infer_reachability_of_if_statement(s, self.options)
93+
for expr in s.expr:
94+
expr.accept(self)
9395
for node in s.body:
9496
node.accept(self)
9597
if s.else_body:

mypy/server/subexpr.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr,
88
IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
99
ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr,
10+
AssignmentExpr,
1011
)
1112
from mypy.traverser import TraverserVisitor
1213

@@ -102,6 +103,10 @@ def visit_reveal_expr(self, e: RevealExpr) -> None:
102103
self.add(e)
103104
super().visit_reveal_expr(e)
104105

106+
def visit_assignment_expr(self, e: AssignmentExpr) -> None:
107+
self.add(e)
108+
super().visit_assignment_expr(e)
109+
105110
def visit_unary_expr(self, e: UnaryExpr) -> None:
106111
self.add(e)
107112
super().visit_unary_expr(e)

mypy/stats.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from mypy.nodes import (
1919
Expression, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr, MypyFile,
2020
MemberExpr, OpExpr, ComparisonExpr, IndexExpr, UnaryExpr, YieldFromExpr, RefExpr, ClassDef,
21-
ImportFrom, Import, ImportAll, PassStmt, BreakStmt, ContinueStmt, StrExpr, BytesExpr,
22-
UnicodeExpr, IntExpr, FloatExpr, ComplexExpr, EllipsisExpr, ExpressionStmt, Node
21+
AssignmentExpr, ImportFrom, Import, ImportAll, PassStmt, BreakStmt, ContinueStmt, StrExpr,
22+
BytesExpr, UnicodeExpr, IntExpr, FloatExpr, ComplexExpr, EllipsisExpr, ExpressionStmt, Node
2323
)
2424
from mypy.util import correct_relative_import
2525
from mypy.argmap import map_formals_to_actuals
@@ -275,6 +275,10 @@ def visit_index_expr(self, o: IndexExpr) -> None:
275275
self.process_node(o)
276276
super().visit_index_expr(o)
277277

278+
def visit_assignment_expr(self, o: AssignmentExpr) -> None:
279+
self.process_node(o)
280+
super().visit_assignment_expr(o)
281+
278282
def visit_unary_expr(self, o: UnaryExpr) -> None:
279283
self.process_node(o)
280284
super().visit_unary_expr(o)

mypy/strconv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,9 @@ def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> str:
425425
# REVEAL_LOCALS
426426
return self.dump([o.local_nodes], o)
427427

428+
def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> str:
429+
return self.dump([o.target, o.value], o)
430+
428431
def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> str:
429432
return self.dump([o.op, o.expr], o)
430433

mypy/traverser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt,
77
ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt,
88
TryStmt, WithStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, RevealExpr,
9-
UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr,
9+
UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr,
1010
GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension,
1111
ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom,
1212
LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr,
@@ -192,6 +192,10 @@ def visit_reveal_expr(self, o: RevealExpr) -> None:
192192
# RevealLocalsExpr doesn't have an inner expression
193193
pass
194194

195+
def visit_assignment_expr(self, o: AssignmentExpr) -> None:
196+
o.target.accept(self)
197+
o.value.accept(self)
198+
195199
def visit_unary_expr(self, o: UnaryExpr) -> None:
196200
o.expr.accept(self)
197201

0 commit comments

Comments
 (0)