Skip to content

Add support for jump statements in partially defined vars check #13632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2337,13 +2337,15 @@ def type_check_second_pass(self) -> bool:
self.time_spent_us += time_spent_us(t0)
return result

def detect_partially_defined_vars(self) -> None:
def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None:
assert self.tree is not None, "Internal error: method must be called on parsed file only"
manager = self.manager
if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED):
manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options)
self.tree.accept(
PartiallyDefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules))
PartiallyDefinedVariableVisitor(
MessageBuilder(manager.errors, manager.modules), type_map
)
)

def finish_passes(self) -> None:
Expand Down Expand Up @@ -3375,7 +3377,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
graph[id].type_check_first_pass()
if not graph[id].type_checker().deferred_nodes:
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].detect_partially_defined_vars(graph[id].type_map())
graph[id].finish_passes()

while unfinished_modules:
Expand All @@ -3384,7 +3386,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No
continue
if not graph[id].type_check_second_pass():
unfinished_modules.discard(id)
graph[id].detect_partially_defined_vars()
graph[id].detect_partially_defined_vars(graph[id].type_map())
graph[id].finish_passes()
for id in stale:
graph[id].generate_unused_ignore_notes()
Expand Down
161 changes: 111 additions & 50 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,129 @@
from __future__ import annotations

from typing import NamedTuple

from mypy import checker
from mypy.messages import MessageBuilder
from mypy.nodes import (
AssertStmt,
AssignmentStmt,
BreakStmt,
ContinueStmt,
Expression,
ExpressionStmt,
ForStmt,
FuncDef,
FuncItem,
GeneratorExpr,
IfStmt,
ListExpr,
Lvalue,
NameExpr,
RaiseStmt,
ReturnStmt,
TupleExpr,
WhileStmt,
)
from mypy.traverser import TraverserVisitor
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import Type, UninhabitedType


class DefinedVars(NamedTuple):
"""DefinedVars contains information about variable definition at the end of a branching statement.
class BranchState:
"""BranchState contains information about variable definition at the end of a branching statement.
`if` and `match` are examples of branching statements.

`may_be_defined` contains variables that were defined in only some branches.
`must_be_defined` contains variables that were defined in all branches.
"""

may_be_defined: set[str]
must_be_defined: set[str]
def __init__(
self,
must_be_defined: set[str] | None = None,
may_be_defined: set[str] | None = None,
skipped: bool = False,
) -> None:
if may_be_defined is None:
may_be_defined = set()
if must_be_defined is None:
must_be_defined = set()

self.may_be_defined = set(may_be_defined)
self.must_be_defined = set(must_be_defined)
self.skipped = skipped


class BranchStatement:
def __init__(self, already_defined: DefinedVars) -> None:
self.already_defined = already_defined
self.defined_by_branch: list[DefinedVars] = [
DefinedVars(may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined))
def __init__(self, initial_state: BranchState) -> None:
self.initial_state = initial_state
self.branches: list[BranchState] = [
BranchState(must_be_defined=self.initial_state.must_be_defined)
]

def next_branch(self) -> None:
self.defined_by_branch.append(
DefinedVars(
may_be_defined=set(), must_be_defined=set(self.already_defined.must_be_defined)
)
)
self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined))

def record_definition(self, name: str) -> None:
assert len(self.defined_by_branch) > 0
self.defined_by_branch[-1].must_be_defined.add(name)
self.defined_by_branch[-1].may_be_defined.discard(name)

def record_nested_branch(self, vars: DefinedVars) -> None:
assert len(self.defined_by_branch) > 0
current_branch = self.defined_by_branch[-1]
current_branch.must_be_defined.update(vars.must_be_defined)
current_branch.may_be_defined.update(vars.may_be_defined)
assert len(self.branches) > 0
self.branches[-1].must_be_defined.add(name)
self.branches[-1].may_be_defined.discard(name)

def record_nested_branch(self, state: BranchState) -> None:
    """Fold the result of a finished nested branch statement into the current branch.

    `state` is the summary of the nested statement. If it is marked as
    skipped, the current branch is marked as skipped too and its variable
    sets are left untouched. Otherwise the nested `must_be_defined` and
    `may_be_defined` names are merged into the current branch.
    """
    assert len(self.branches) > 0
    target = self.branches[-1]
    if not state.skipped:
        target.must_be_defined.update(state.must_be_defined)
        target.may_be_defined.update(state.may_be_defined)
        # A name that is definitely defined must not also be listed as "maybe".
        target.may_be_defined.difference_update(target.must_be_defined)
    else:
        target.skipped = True

def skip_branch(self) -> None:
    """Flag the branch currently being recorded as skipped."""
    branches = self.branches
    assert len(branches) > 0
    branches[-1].skipped = True

def is_possibly_undefined(self, name: str) -> bool:
assert len(self.defined_by_branch) > 0
return name in self.defined_by_branch[-1].may_be_defined
assert len(self.branches) > 0
return name in self.branches[-1].may_be_defined

def done(self) -> DefinedVars:
assert len(self.defined_by_branch) > 0
if len(self.defined_by_branch) == 1:
# If there's only one branch, then we just return current.
# Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`).
return self.defined_by_branch[0]
def done(self) -> BranchState:
branches = [b for b in self.branches if not b.skipped]
if len(branches) == 0:
return BranchState(skipped=True)
if len(branches) == 1:
return branches[0]

# must_be_defined is the intersection of must_be_defined of all branches.
must_be_defined = set(self.defined_by_branch[0].must_be_defined)
for branch_vars in self.defined_by_branch[1:]:
must_be_defined.intersection_update(branch_vars.must_be_defined)
must_be_defined = set(branches[0].must_be_defined)
for b in branches[1:]:
must_be_defined.intersection_update(b.must_be_defined)
# may_be_defined contains every variable seen in any branch that is not in must_be_defined.
all_vars = set()
for branch_vars in self.defined_by_branch:
all_vars.update(branch_vars.may_be_defined)
all_vars.update(branch_vars.must_be_defined)
for b in branches:
all_vars.update(b.may_be_defined)
all_vars.update(b.must_be_defined)
may_be_defined = all_vars.difference(must_be_defined)
return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined)
return BranchState(may_be_defined=may_be_defined, must_be_defined=must_be_defined)


class DefinedVariableTracker:
"""DefinedVariableTracker manages the state and scope for the PartiallyDefinedVariableVisitor."""

def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchStatement.
self.scopes: list[list[BranchStatement]] = [
[BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))]
]
self.scopes: list[list[BranchStatement]] = [[BranchStatement(BranchState())]]

def _scope(self) -> list[BranchStatement]:
    """Return the innermost scope, i.e. the last entry of `self.scopes`."""
    scopes = self.scopes
    assert len(scopes) > 0
    return scopes[-1]

def enter_scope(self) -> None:
assert len(self._scope()) > 0
self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])])
self.scopes.append([BranchStatement(self._scope()[-1].branches[-1])])

def exit_scope(self) -> None:
    """Leave the innermost scope by removing the last entry of `self.scopes`."""
    self.scopes.pop()

def start_branch_statement(self) -> None:
assert len(self._scope()) > 0
self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1]))
self._scope().append(BranchStatement(self._scope()[-1].branches[-1]))

def next_branch(self) -> None:
assert len(self._scope()) > 1
Expand All @@ -113,6 +134,11 @@ def end_branch_statement(self) -> None:
result = self._scope().pop().done()
self._scope()[-1].record_nested_branch(result)

def skip_branch(self) -> None:
    """Mark the current branch of the innermost branch statement as skipped.

    Nothing happens when only the "root" branch statement of the current
    scope is open.
    """
    statements = self._scope()
    if len(statements) <= 1:
        # Only the root statement is open: there is no branch to skip.
        return
    statements[-1].skip_branch()

def record_declaration(self, name: str) -> None:
assert len(self.scopes) > 0
assert len(self.scopes[-1]) > 0
Expand All @@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool:
return self._scope()[-1].is_possibly_undefined(name)


class PartiallyDefinedVariableVisitor(TraverserVisitor):
class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor):
"""Detect variables that are defined only part of the time.

This visitor detects the following case:
Expand All @@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor):
handled by the semantic analyzer.
"""

def __init__(self, msg: MessageBuilder) -> None:
def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> None:
self.msg = msg
self.type_map = type_map
self.tracker = DefinedVariableTracker()

def process_lvalue(self, lvalue: Lvalue) -> None:
Expand Down Expand Up @@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None:
self.tracker.record_declaration(arg.variable.name)
super().visit_func(o)

def visit_generator_expr(self, o: GeneratorExpr) -> None:
    """Traverse a generator expression inside its own tracker scope.

    Each entry of `o.indices` is passed to `process_lvalue` before the
    default traversal runs; the scope is exited afterwards.
    """
    tracker = self.tracker
    tracker.enter_scope()
    for index_lvalue in o.indices:
        self.process_lvalue(index_lvalue)
    super().visit_generator_expr(o)
    tracker.exit_scope()

def visit_for_stmt(self, o: ForStmt) -> None:
o.expr.accept(self)
self.process_lvalue(o.index)
Expand All @@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_return_stmt(self, o: ReturnStmt) -> None:
    """Run the default traversal of a `return`, then mark the current branch as skipped."""
    super().visit_return_stmt(o)
    tracker = self.tracker
    tracker.skip_branch()

def visit_assert_stmt(self, o: AssertStmt) -> None:
    """Run the default traversal of an `assert`.

    The current branch is marked as skipped only when
    `checker.is_false_literal` reports the asserted expression as a false
    literal.
    """
    super().visit_assert_stmt(o)
    always_fails = checker.is_false_literal(o.expr)
    if always_fails:
        self.tracker.skip_branch()

def visit_raise_stmt(self, o: RaiseStmt) -> None:
    """Run the default traversal of a `raise`, then mark the current branch as skipped."""
    super().visit_raise_stmt(o)
    tracker = self.tracker
    tracker.skip_branch()

def visit_continue_stmt(self, o: ContinueStmt) -> None:
    """Run the default traversal of a `continue`, then mark the current branch as skipped."""
    super().visit_continue_stmt(o)
    tracker = self.tracker
    tracker.skip_branch()

def visit_break_stmt(self, o: BreakStmt) -> None:
    """Run the default traversal of a `break`, then mark the current branch as skipped."""
    super().visit_break_stmt(o)
    tracker = self.tracker
    tracker.skip_branch()

def visit_expression_stmt(self, o: ExpressionStmt) -> None:
    """Mark the current branch as skipped when the statement's expression is uninhabited.

    The expression's type is looked up in `self.type_map`; the branch is
    skipped only if that type is an `UninhabitedType` (presumably a call
    that never returns -- confirm against the type checker). The default
    traversal runs afterwards in either case.
    """
    expr_type = self.type_map.get(o.expr)
    if isinstance(expr_type, UninhabitedType):
        self.tracker.skip_branch()
    super().visit_expression_stmt(o)

def visit_while_stmt(self, o: WhileStmt) -> None:
o.expr.accept(self)
self.tracker.start_branch_statement()
o.body.accept(self)
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
if not checker.is_true_literal(o.expr):
self.tracker.next_branch()
if o.else_body:
o.else_body.accept(self)
self.tracker.end_branch_statement()

def visit_name_expr(self, o: NameExpr) -> None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/server/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def restore(ids: list[str]) -> None:
state.type_checker().reset()
state.type_check_first_pass()
state.type_check_second_pass()
state.detect_partially_defined_vars()
state.detect_partially_defined_vars(state.type_map())
t2 = time.time()
state.finish_passes()
t3 = time.time()
Expand Down
Loading