1
1
from __future__ import annotations
2
2
3
- from typing import NamedTuple
4
-
3
+ from mypy import checker
5
4
from mypy .messages import MessageBuilder
6
5
from mypy .nodes import (
6
+ AssertStmt ,
7
7
AssignmentStmt ,
8
+ BreakStmt ,
9
+ ContinueStmt ,
10
+ Expression ,
11
+ ExpressionStmt ,
8
12
ForStmt ,
9
13
FuncDef ,
10
14
FuncItem ,
15
+ GeneratorExpr ,
11
16
IfStmt ,
12
17
ListExpr ,
13
18
Lvalue ,
14
19
NameExpr ,
20
+ RaiseStmt ,
21
+ ReturnStmt ,
15
22
TupleExpr ,
16
23
WhileStmt ,
17
24
)
18
- from mypy .traverser import TraverserVisitor
25
+ from mypy .traverser import ExtendedTraverserVisitor
26
+ from mypy .types import Type , UninhabitedType
19
27
20
28
21
- class DefinedVars ( NamedTuple ) :
22
- """DefinedVars contains information about variable definition at the end of a branching statement.
29
+ class BranchState :
30
+ """BranchState contains information about variable definition at the end of a branching statement.
23
31
`if` and `match` are examples of branching statements.
24
32
25
33
`may_be_defined` contains variables that were defined in only some branches.
26
34
`must_be_defined` contains variables that were defined in all branches.
27
35
"""
28
36
29
- may_be_defined : set [str ]
30
- must_be_defined : set [str ]
37
+ def __init__ (
38
+ self ,
39
+ must_be_defined : set [str ] | None = None ,
40
+ may_be_defined : set [str ] | None = None ,
41
+ skipped : bool = False ,
42
+ ) -> None :
43
+ if may_be_defined is None :
44
+ may_be_defined = set ()
45
+ if must_be_defined is None :
46
+ must_be_defined = set ()
47
+
48
+ self .may_be_defined = set (may_be_defined )
49
+ self .must_be_defined = set (must_be_defined )
50
+ self .skipped = skipped
31
51
32
52
33
53
class BranchStatement :
34
- def __init__ (self , already_defined : DefinedVars ) -> None :
35
- self .already_defined = already_defined
36
- self .defined_by_branch : list [DefinedVars ] = [
37
- DefinedVars ( may_be_defined = set (), must_be_defined = set ( already_defined . must_be_defined ) )
54
+ def __init__ (self , initial_state : BranchState ) -> None :
55
+ self .initial_state = initial_state
56
+ self .branches : list [BranchState ] = [
57
+ BranchState ( must_be_defined = self . initial_state . must_be_defined )
38
58
]
39
59
40
60
def next_branch (self ) -> None :
41
- self .defined_by_branch .append (
42
- DefinedVars (
43
- may_be_defined = set (), must_be_defined = set (self .already_defined .must_be_defined )
44
- )
45
- )
61
+ self .branches .append (BranchState (must_be_defined = self .initial_state .must_be_defined ))
46
62
47
63
def record_definition (self , name : str ) -> None :
48
- assert len (self .defined_by_branch ) > 0
49
- self .defined_by_branch [- 1 ].must_be_defined .add (name )
50
- self .defined_by_branch [- 1 ].may_be_defined .discard (name )
51
-
52
- def record_nested_branch (self , vars : DefinedVars ) -> None :
53
- assert len (self .defined_by_branch ) > 0
54
- current_branch = self .defined_by_branch [- 1 ]
55
- current_branch .must_be_defined .update (vars .must_be_defined )
56
- current_branch .may_be_defined .update (vars .may_be_defined )
64
+ assert len (self .branches ) > 0
65
+ self .branches [- 1 ].must_be_defined .add (name )
66
+ self .branches [- 1 ].may_be_defined .discard (name )
67
+
68
+ def record_nested_branch (self , state : BranchState ) -> None :
69
+ assert len (self .branches ) > 0
70
+ current_branch = self .branches [- 1 ]
71
+ if state .skipped :
72
+ current_branch .skipped = True
73
+ return
74
+ current_branch .must_be_defined .update (state .must_be_defined )
75
+ current_branch .may_be_defined .update (state .may_be_defined )
57
76
current_branch .may_be_defined .difference_update (current_branch .must_be_defined )
58
77
78
+ def skip_branch (self ) -> None :
79
+ assert len (self .branches ) > 0
80
+ self .branches [- 1 ].skipped = True
81
+
59
82
def is_possibly_undefined (self , name : str ) -> bool :
60
- assert len (self .defined_by_branch ) > 0
61
- return name in self .defined_by_branch [- 1 ].may_be_defined
83
+ assert len (self .branches ) > 0
84
+ return name in self .branches [- 1 ].may_be_defined
62
85
63
- def done (self ) -> DefinedVars :
64
- assert len ( self .defined_by_branch ) > 0
65
- if len (self . defined_by_branch ) == 1 :
66
- # If there's only one branch, then we just return current.
67
- # Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
68
- return self . defined_by_branch [0 ]
86
+ def done (self ) -> BranchState :
87
+ branches = [ b for b in self .branches if not b . skipped ]
88
+ if len (branches ) == 0 :
89
+ return BranchState ( skipped = True )
90
+ if len ( branches ) == 1 :
91
+ return branches [0 ]
69
92
70
93
# must_be_defined is a union of must_be_defined of all branches.
71
- must_be_defined = set (self . defined_by_branch [0 ].must_be_defined )
72
- for branch_vars in self . defined_by_branch [1 :]:
73
- must_be_defined .intersection_update (branch_vars .must_be_defined )
94
+ must_be_defined = set (branches [0 ].must_be_defined )
95
+ for b in branches [1 :]:
96
+ must_be_defined .intersection_update (b .must_be_defined )
74
97
# may_be_defined are all variables that are not must be defined.
75
98
all_vars = set ()
76
- for branch_vars in self . defined_by_branch :
77
- all_vars .update (branch_vars .may_be_defined )
78
- all_vars .update (branch_vars .must_be_defined )
99
+ for b in branches :
100
+ all_vars .update (b .may_be_defined )
101
+ all_vars .update (b .must_be_defined )
79
102
may_be_defined = all_vars .difference (must_be_defined )
80
- return DefinedVars (may_be_defined = may_be_defined , must_be_defined = must_be_defined )
103
+ return BranchState (may_be_defined = may_be_defined , must_be_defined = must_be_defined )
81
104
82
105
83
106
class DefinedVariableTracker :
84
107
"""DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor."""
85
108
86
109
def __init__ (self ) -> None :
87
110
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
88
- self .scopes : list [list [BranchStatement ]] = [
89
- [BranchStatement (DefinedVars (may_be_defined = set (), must_be_defined = set ()))]
90
- ]
111
+ self .scopes : list [list [BranchStatement ]] = [[BranchStatement (BranchState ())]]
91
112
92
113
def _scope (self ) -> list [BranchStatement ]:
93
114
assert len (self .scopes ) > 0
94
115
return self .scopes [- 1 ]
95
116
96
117
def enter_scope (self ) -> None :
97
118
assert len (self ._scope ()) > 0
98
- self .scopes .append ([BranchStatement (self ._scope ()[- 1 ].defined_by_branch [- 1 ])])
119
+ self .scopes .append ([BranchStatement (self ._scope ()[- 1 ].branches [- 1 ])])
99
120
100
121
def exit_scope (self ) -> None :
101
122
self .scopes .pop ()
102
123
103
124
def start_branch_statement (self ) -> None :
104
125
assert len (self ._scope ()) > 0
105
- self ._scope ().append (BranchStatement (self ._scope ()[- 1 ].defined_by_branch [- 1 ]))
126
+ self ._scope ().append (BranchStatement (self ._scope ()[- 1 ].branches [- 1 ]))
106
127
107
128
def next_branch (self ) -> None :
108
129
assert len (self ._scope ()) > 1
@@ -113,6 +134,11 @@ def end_branch_statement(self) -> None:
113
134
result = self ._scope ().pop ().done ()
114
135
self ._scope ()[- 1 ].record_nested_branch (result )
115
136
137
+ def skip_branch (self ) -> None :
138
+ # Only skip branch if we're outside of "root" branch statement.
139
+ if len (self ._scope ()) > 1 :
140
+ self ._scope ()[- 1 ].skip_branch ()
141
+
116
142
def record_declaration (self , name : str ) -> None :
117
143
assert len (self .scopes ) > 0
118
144
assert len (self .scopes [- 1 ]) > 0
@@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool:
125
151
return self ._scope ()[- 1 ].is_possibly_undefined (name )
126
152
127
153
128
- class PartiallyDefinedVariableVisitor (TraverserVisitor ):
154
+ class PartiallyDefinedVariableVisitor (ExtendedTraverserVisitor ):
129
155
"""Detect variables that are defined only part of the time.
130
156
131
157
This visitor detects the following case:
@@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor):
137
163
handled by the semantic analyzer.
138
164
"""
139
165
140
- def __init__ (self , msg : MessageBuilder ) -> None :
166
+ def __init__ (self , msg : MessageBuilder , type_map : dict [ Expression , Type ] ) -> None :
141
167
self .msg = msg
168
+ self .type_map = type_map
142
169
self .tracker = DefinedVariableTracker ()
143
170
144
171
def process_lvalue (self , lvalue : Lvalue ) -> None :
@@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None:
175
202
self .tracker .record_declaration (arg .variable .name )
176
203
super ().visit_func (o )
177
204
205
+ def visit_generator_expr (self , o : GeneratorExpr ) -> None :
206
+ self .tracker .enter_scope ()
207
+ for idx in o .indices :
208
+ self .process_lvalue (idx )
209
+ super ().visit_generator_expr (o )
210
+ self .tracker .exit_scope ()
211
+
178
212
def visit_for_stmt (self , o : ForStmt ) -> None :
179
213
o .expr .accept (self )
180
214
self .process_lvalue (o .index )
@@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None:
186
220
o .else_body .accept (self )
187
221
self .tracker .end_branch_statement ()
188
222
223
+ def visit_return_stmt (self , o : ReturnStmt ) -> None :
224
+ super ().visit_return_stmt (o )
225
+ self .tracker .skip_branch ()
226
+
227
+ def visit_assert_stmt (self , o : AssertStmt ) -> None :
228
+ super ().visit_assert_stmt (o )
229
+ if checker .is_false_literal (o .expr ):
230
+ self .tracker .skip_branch ()
231
+
232
+ def visit_raise_stmt (self , o : RaiseStmt ) -> None :
233
+ super ().visit_raise_stmt (o )
234
+ self .tracker .skip_branch ()
235
+
236
+ def visit_continue_stmt (self , o : ContinueStmt ) -> None :
237
+ super ().visit_continue_stmt (o )
238
+ self .tracker .skip_branch ()
239
+
240
+ def visit_break_stmt (self , o : BreakStmt ) -> None :
241
+ super ().visit_break_stmt (o )
242
+ self .tracker .skip_branch ()
243
+
244
+ def visit_expression_stmt (self , o : ExpressionStmt ) -> None :
245
+ if isinstance (self .type_map .get (o .expr , None ), UninhabitedType ):
246
+ self .tracker .skip_branch ()
247
+ super ().visit_expression_stmt (o )
248
+
189
249
def visit_while_stmt (self , o : WhileStmt ) -> None :
190
250
o .expr .accept (self )
191
251
self .tracker .start_branch_statement ()
192
252
o .body .accept (self )
193
- self .tracker .next_branch ()
194
- if o .else_body :
195
- o .else_body .accept (self )
253
+ if not checker .is_true_literal (o .expr ):
254
+ self .tracker .next_branch ()
255
+ if o .else_body :
256
+ o .else_body .accept (self )
196
257
self .tracker .end_branch_statement ()
197
258
198
259
def visit_name_expr (self , o : NameExpr ) -> None :
0 commit comments