Skip to content

Commit 216a45b

Browse files
authored
Add support for jump statements in partially defined vars check (#13632)
This builds on #13601 to add support for jump statements such as `continue`, `break`, `return`, and `raise` in the partially defined variables check. The simplest example is:

```python
def f1() -> int:
    if int():
        x = 1
    else:
        return 0
    return x
```

Previously, mypy would generate a false positive on the last line of this example. See the test cases for more details.

Adding this support was relatively simple, given all the already existing code.

Things that aren't supported yet: `match`, `with`, and detecting unreachable blocks.

After this PR, enabling this check on mypy itself generates 18 errors, all of which are potential bugs.
1 parent 0f17aff commit 216a45b

File tree

4 files changed

+300
-55
lines changed

4 files changed

+300
-55
lines changed

mypy/build.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -2331,13 +2331,15 @@ def type_check_second_pass(self) -> bool:
23312331
self.time_spent_us += time_spent_us(t0)
23322332
return result
23332333

2334-
def detect_partially_defined_vars(self) -> None:
2334+
def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None:
23352335
assert self.tree is not None, "Internal error: method must be called on parsed file only"
23362336
manager = self.manager
23372337
if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED):
23382338
manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
23392339
self.tree.accept(
2340-
PartiallyDefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules))
2340+
PartiallyDefinedVariableVisitor(
2341+
MessageBuilder(manager.errors, manager.modules), type_map
2342+
)
23412343
)
23422344

23432345
def finish_passes(self) -> None:
@@ -3368,7 +3370,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
33683370
graph[id].type_check_first_pass()
33693371
if not graph[id].type_checker().deferred_nodes:
33703372
unfinished_modules.discard(id)
3371-
graph[id].detect_partially_defined_vars()
3373+
graph[id].detect_partially_defined_vars(graph[id].type_map())
33723374
graph[id].finish_passes()
33733375

33743376
while unfinished_modules:
@@ -3377,7 +3379,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
33773379
continue
33783380
if not graph[id].type_check_second_pass():
33793381
unfinished_modules.discard(id)
3380-
graph[id].detect_partially_defined_vars()
3382+
graph[id].detect_partially_defined_vars(graph[id].type_map())
33813383
graph[id].finish_passes()
33823384
for id in stale:
33833385
graph[id].generate_unused_ignore_notes()

mypy/partially_defined.py

+111-50
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,129 @@
11
from __future__ import annotations
22

3-
from typing import NamedTuple
4-
3+
from mypy import checker
54
from mypy.messages import MessageBuilder
65
from mypy.nodes import (
6+
AssertStmt,
77
AssignmentStmt,
8+
BreakStmt,
9+
ContinueStmt,
10+
Expression,
11+
ExpressionStmt,
812
ForStmt,
913
FuncDef,
1014
FuncItem,
15+
GeneratorExpr,
1116
IfStmt,
1217
ListExpr,
1318
Lvalue,
1419
NameExpr,
20+
RaiseStmt,
21+
ReturnStmt,
1522
TupleExpr,
1623
WhileStmt,
1724
)
18-
from mypy.traverser import TraverserVisitor
25+
from mypy.traverser import ExtendedTraverserVisitor
26+
from mypy.types import Type, UninhabitedType
1927

2028

21-
class DefinedVars(NamedTuple):
22-
"""DefinedVars contains information about variable definition at the end of a branching statement.
29+
class BranchState:
30+
"""BranchState contains information about variable definition at the end of a branching statement.
2331
`if` and `match` are examples of branching statements.
2432
2533
`may_be_defined` contains variables that were defined in only some branches.
2634
`must_be_defined` contains variables that were defined in all branches.
2735
"""
2836

29-
may_be_defined: set[str]
30-
must_be_defined: set[str]
37+
def __init__(
38+
self,
39+
must_be_defined: set[str] | None = None,
40+
may_be_defined: set[str] | None = None,
41+
skipped: bool = False,
42+
) -> None:
43+
if may_be_defined is None:
44+
may_be_defined = set()
45+
if must_be_defined is None:
46+
must_be_defined = set()
47+
48+
self.may_be_defined = set(may_be_defined)
49+
self.must_be_defined = set(must_be_defined)
50+
self.skipped = skipped
3151

3252

3353
class BranchStatement:
34-
def __init__(self, already_defined: DefinedVars) -> None:
35-
self.already_defined = already_defined
36-
self.defined_by_branch: list[DefinedVars] = [
37-
DefinedVars(may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined))
54+
def __init__(self, initial_state: BranchState) -> None:
55+
self.initial_state = initial_state
56+
self.branches: list[BranchState] = [
57+
BranchState(must_be_defined=self.initial_state.must_be_defined)
3858
]
3959

4060
def next_branch(self) -> None:
41-
self.defined_by_branch.append(
42-
DefinedVars(
43-
may_be_defined=set(), must_be_defined=set(self.already_defined.must_be_defined)
44-
)
45-
)
61+
self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined))
4662

4763
def record_definition(self, name: str) -> None:
48-
assert len(self.defined_by_branch) > 0
49-
self.defined_by_branch[-1].must_be_defined.add(name)
50-
self.defined_by_branch[-1].may_be_defined.discard(name)
51-
52-
def record_nested_branch(self, vars: DefinedVars) -> None:
53-
assert len(self.defined_by_branch) > 0
54-
current_branch = self.defined_by_branch[-1]
55-
current_branch.must_be_defined.update(vars.must_be_defined)
56-
current_branch.may_be_defined.update(vars.may_be_defined)
64+
assert len(self.branches) > 0
65+
self.branches[-1].must_be_defined.add(name)
66+
self.branches[-1].may_be_defined.discard(name)
67+
68+
def record_nested_branch(self, state: BranchState) -> None:
69+
assert len(self.branches) > 0
70+
current_branch = self.branches[-1]
71+
if state.skipped:
72+
current_branch.skipped = True
73+
return
74+
current_branch.must_be_defined.update(state.must_be_defined)
75+
current_branch.may_be_defined.update(state.may_be_defined)
5776
current_branch.may_be_defined.difference_update(current_branch.must_be_defined)
5877

78+
def skip_branch(self) -> None:
79+
assert len(self.branches) > 0
80+
self.branches[-1].skipped = True
81+
5982
def is_possibly_undefined(self, name: str) -> bool:
60-
assert len(self.defined_by_branch) > 0
61-
return name in self.defined_by_branch[-1].may_be_defined
83+
assert len(self.branches) > 0
84+
return name in self.branches[-1].may_be_defined
6285

63-
def done(self) -> DefinedVars:
64-
assert len(self.defined_by_branch) > 0
65-
if len(self.defined_by_branch) == 1:
66-
# If there's only one branch, then we just return current.
67-
# Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
68-
return self.defined_by_branch[0]
86+
def done(self) -> BranchState:
87+
branches = [b for b in self.branches if not b.skipped]
88+
if len(branches) == 0:
89+
return BranchState(skipped=True)
90+
if len(branches) == 1:
91+
return branches[0]
6992

7093
# must_be_defined is the intersection of must_be_defined of all (non-skipped) branches.
71-
must_be_defined = set(self.defined_by_branch[0].must_be_defined)
72-
for branch_vars in self.defined_by_branch[1:]:
73-
must_be_defined.intersection_update(branch_vars.must_be_defined)
94+
must_be_defined = set(branches[0].must_be_defined)
95+
for b in branches[1:]:
96+
must_be_defined.intersection_update(b.must_be_defined)
7497
# may_be_defined is every variable seen in any branch that is not in must_be_defined.
7598
all_vars = set()
76-
for branch_vars in self.defined_by_branch:
77-
all_vars.update(branch_vars.may_be_defined)
78-
all_vars.update(branch_vars.must_be_defined)
99+
for b in branches:
100+
all_vars.update(b.may_be_defined)
101+
all_vars.update(b.must_be_defined)
79102
may_be_defined = all_vars.difference(must_be_defined)
80-
return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined)
103+
return BranchState(may_be_defined=may_be_defined, must_be_defined=must_be_defined)
81104

82105

83106
class DefinedVariableTracker:
84107
"""DefinedVariableTracker manages the state and scope for the PartiallyDefinedVariableVisitor."""
85108

86109
def __init__(self) -> None:
87110
# There's always at least one scope. Within each scope, there's at least one "global" BranchStatement.
88-
self.scopes: list[list[BranchStatement]] = [
89-
[BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))]
90-
]
111+
self.scopes: list[list[BranchStatement]] = [[BranchStatement(BranchState())]]
91112

92113
def _scope(self) -> list[BranchStatement]:
93114
assert len(self.scopes) > 0
94115
return self.scopes[-1]
95116

96117
def enter_scope(self) -> None:
97118
assert len(self._scope()) > 0
98-
self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])])
119+
self.scopes.append([BranchStatement(self._scope()[-1].branches[-1])])
99120

100121
def exit_scope(self) -> None:
101122
self.scopes.pop()
102123

103124
def start_branch_statement(self) -> None:
104125
assert len(self._scope()) > 0
105-
self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1]))
126+
self._scope().append(BranchStatement(self._scope()[-1].branches[-1]))
106127

107128
def next_branch(self) -> None:
108129
assert len(self._scope()) > 1
@@ -113,6 +134,11 @@ def end_branch_statement(self) -> None:
113134
result = self._scope().pop().done()
114135
self._scope()[-1].record_nested_branch(result)
115136

137+
def skip_branch(self) -> None:
138+
# Only skip branch if we're outside of "root" branch statement.
139+
if len(self._scope()) > 1:
140+
self._scope()[-1].skip_branch()
141+
116142
def record_declaration(self, name: str) -> None:
117143
assert len(self.scopes) > 0
118144
assert len(self.scopes[-1]) > 0
@@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool:
125151
return self._scope()[-1].is_possibly_undefined(name)
126152

127153

128-
class PartiallyDefinedVariableVisitor(TraverserVisitor):
154+
class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor):
129155
"""Detect variables that are defined only part of the time.
130156
131157
This visitor detects the following case:
@@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor):
137163
handled by the semantic analyzer.
138164
"""
139165

140-
def __init__(self, msg: MessageBuilder) -> None:
166+
def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> None:
141167
self.msg = msg
168+
self.type_map = type_map
142169
self.tracker = DefinedVariableTracker()
143170

144171
def process_lvalue(self, lvalue: Lvalue) -> None:
@@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None:
175202
self.tracker.record_declaration(arg.variable.name)
176203
super().visit_func(o)
177204

205+
def visit_generator_expr(self, o: GeneratorExpr) -> None:
206+
self.tracker.enter_scope()
207+
for idx in o.indices:
208+
self.process_lvalue(idx)
209+
super().visit_generator_expr(o)
210+
self.tracker.exit_scope()
211+
178212
def visit_for_stmt(self, o: ForStmt) -> None:
179213
o.expr.accept(self)
180214
self.process_lvalue(o.index)
@@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None:
186220
o.else_body.accept(self)
187221
self.tracker.end_branch_statement()
188222

223+
def visit_return_stmt(self, o: ReturnStmt) -> None:
224+
super().visit_return_stmt(o)
225+
self.tracker.skip_branch()
226+
227+
def visit_assert_stmt(self, o: AssertStmt) -> None:
228+
super().visit_assert_stmt(o)
229+
if checker.is_false_literal(o.expr):
230+
self.tracker.skip_branch()
231+
232+
def visit_raise_stmt(self, o: RaiseStmt) -> None:
233+
super().visit_raise_stmt(o)
234+
self.tracker.skip_branch()
235+
236+
def visit_continue_stmt(self, o: ContinueStmt) -> None:
237+
super().visit_continue_stmt(o)
238+
self.tracker.skip_branch()
239+
240+
def visit_break_stmt(self, o: BreakStmt) -> None:
241+
super().visit_break_stmt(o)
242+
self.tracker.skip_branch()
243+
244+
def visit_expression_stmt(self, o: ExpressionStmt) -> None:
245+
if isinstance(self.type_map.get(o.expr, None), UninhabitedType):
246+
self.tracker.skip_branch()
247+
super().visit_expression_stmt(o)
248+
189249
def visit_while_stmt(self, o: WhileStmt) -> None:
190250
o.expr.accept(self)
191251
self.tracker.start_branch_statement()
192252
o.body.accept(self)
193-
self.tracker.next_branch()
194-
if o.else_body:
195-
o.else_body.accept(self)
253+
if not checker.is_true_literal(o.expr):
254+
self.tracker.next_branch()
255+
if o.else_body:
256+
o.else_body.accept(self)
196257
self.tracker.end_branch_statement()
197258

198259
def visit_name_expr(self, o: NameExpr) -> None:

mypy/server/update.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def restore(ids: list[str]) -> None:
651651
state.type_checker().reset()
652652
state.type_check_first_pass()
653653
state.type_check_second_pass()
654-
state.detect_partially_defined_vars()
654+
state.detect_partially_defined_vars(state.type_map())
655655
t2 = time.time()
656656
state.finish_passes()
657657
t3 = time.time()

0 commit comments

Comments
 (0)