diff --git a/mypy/build.py b/mypy/build.py index 1720eedaad10..b8d8e941aae8 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -2337,13 +2337,15 @@ def type_check_second_pass(self) -> bool: self.time_spent_us += time_spent_us(t0) return result - def detect_partially_defined_vars(self) -> None: + def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None: assert self.tree is not None, "Internal error: method must be called on parsed file only" manager = self.manager if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED): manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options) self.tree.accept( - PartiallyDefinedVariableVisitor(MessageBuilder(manager.errors, manager.modules)) + PartiallyDefinedVariableVisitor( + MessageBuilder(manager.errors, manager.modules), type_map + ) ) def finish_passes(self) -> None: @@ -3375,7 +3377,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No graph[id].type_check_first_pass() if not graph[id].type_checker().deferred_nodes: unfinished_modules.discard(id) - graph[id].detect_partially_defined_vars() + graph[id].detect_partially_defined_vars(graph[id].type_map()) graph[id].finish_passes() while unfinished_modules: @@ -3384,7 +3386,7 @@ def process_stale_scc(graph: Graph, scc: list[str], manager: BuildManager) -> No continue if not graph[id].type_check_second_pass(): unfinished_modules.discard(id) - graph[id].detect_partially_defined_vars() + graph[id].detect_partially_defined_vars(graph[id].type_map()) graph[id].finish_passes() for id in stale: graph[id].generate_unused_ignore_notes() diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index ac8d2f8d3c01..2f7e002dd2dd 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -1,83 +1,106 @@ from __future__ import annotations -from typing import NamedTuple - +from mypy import checker from mypy.messages import MessageBuilder from mypy.nodes import ( + AssertStmt, AssignmentStmt, + BreakStmt, + ContinueStmt, + Expression, + ExpressionStmt, ForStmt, FuncDef, FuncItem, + GeneratorExpr, IfStmt, ListExpr, Lvalue, NameExpr, + RaiseStmt, + ReturnStmt, TupleExpr, WhileStmt, ) -from mypy.traverser import TraverserVisitor +from mypy.traverser import ExtendedTraverserVisitor +from mypy.types import Type, UninhabitedType -class DefinedVars(NamedTuple): - """DefinedVars contains information about variable definition at the end of a branching statement. +class BranchState: + """BranchState contains information about variable definition at the end of a branching statement. `if` and `match` are examples of branching statements. `may_be_defined` contains variables that were defined in only some branches. `must_be_defined` contains variables that were defined in all branches. """ - may_be_defined: set[str] - must_be_defined: set[str] + def __init__( + self, + must_be_defined: set[str] | None = None, + may_be_defined: set[str] | None = None, + skipped: bool = False, + ) -> None: + if may_be_defined is None: + may_be_defined = set() + if must_be_defined is None: + must_be_defined = set() + + self.may_be_defined = set(may_be_defined) + self.must_be_defined = set(must_be_defined) + self.skipped = skipped class BranchStatement: - def __init__(self, already_defined: DefinedVars) -> None: - self.already_defined = already_defined - self.defined_by_branch: list[DefinedVars] = [ - DefinedVars(may_be_defined=set(), must_be_defined=set(already_defined.must_be_defined)) + def __init__(self, initial_state: BranchState) -> None: + self.initial_state = initial_state + self.branches: list[BranchState] = [ + BranchState(must_be_defined=self.initial_state.must_be_defined) ] def next_branch(self) -> None: - self.defined_by_branch.append( - DefinedVars( - may_be_defined=set(), must_be_defined=set(self.already_defined.must_be_defined) - ) - ) + self.branches.append(BranchState(must_be_defined=self.initial_state.must_be_defined)) def record_definition(self, name: str) -> None: - assert len(self.defined_by_branch) > 0 - self.defined_by_branch[-1].must_be_defined.add(name) - self.defined_by_branch[-1].may_be_defined.discard(name) - - def record_nested_branch(self, vars: DefinedVars) -> None: - assert len(self.defined_by_branch) > 0 - current_branch = self.defined_by_branch[-1] - current_branch.must_be_defined.update(vars.must_be_defined) - current_branch.may_be_defined.update(vars.may_be_defined) + assert len(self.branches) > 0 + self.branches[-1].must_be_defined.add(name) + self.branches[-1].may_be_defined.discard(name) + + def record_nested_branch(self, state: BranchState) -> None: + assert len(self.branches) > 0 + current_branch = self.branches[-1] + if state.skipped: + current_branch.skipped = True + return + current_branch.must_be_defined.update(state.must_be_defined) + current_branch.may_be_defined.update(state.may_be_defined) current_branch.may_be_defined.difference_update(current_branch.must_be_defined) + def skip_branch(self) -> None: + assert len(self.branches) > 0 + self.branches[-1].skipped = True + def is_possibly_undefined(self, name: str) -> bool: - assert len(self.defined_by_branch) > 0 - return name in self.defined_by_branch[-1].may_be_defined + assert len(self.branches) > 0 + return name in self.branches[-1].may_be_defined - def done(self) -> DefinedVars: - assert len(self.defined_by_branch) > 0 - if len(self.defined_by_branch) == 1: - # If there's only one branch, then we just return current. - # Note that this case is a different case when an empty branch is omitted (e.g. `if` without `else`). - return self.defined_by_branch[0] + def done(self) -> BranchState: + branches = [b for b in self.branches if not b.skipped] + if len(branches) == 0: + return BranchState(skipped=True) + if len(branches) == 1: + return branches[0] # must_be_defined is a union of must_be_defined of all branches. - must_be_defined = set(self.defined_by_branch[0].must_be_defined) - for branch_vars in self.defined_by_branch[1:]: - must_be_defined.intersection_update(branch_vars.must_be_defined) + must_be_defined = set(branches[0].must_be_defined) + for b in branches[1:]: + must_be_defined.intersection_update(b.must_be_defined) # may_be_defined are all variables that are not must be defined. all_vars = set() - for branch_vars in self.defined_by_branch: - all_vars.update(branch_vars.may_be_defined) - all_vars.update(branch_vars.must_be_defined) + for b in branches: + all_vars.update(b.may_be_defined) + all_vars.update(b.must_be_defined) may_be_defined = all_vars.difference(must_be_defined) - return DefinedVars(may_be_defined=may_be_defined, must_be_defined=must_be_defined) + return BranchState(may_be_defined=may_be_defined, must_be_defined=must_be_defined) class DefinedVariableTracker: @@ -85,9 +108,7 @@ class DefinedVariableTracker: def __init__(self) -> None: # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement. - self.scopes: list[list[BranchStatement]] = [ - [BranchStatement(DefinedVars(may_be_defined=set(), must_be_defined=set()))] - ] + self.scopes: list[list[BranchStatement]] = [[BranchStatement(BranchState())]] def _scope(self) -> list[BranchStatement]: assert len(self.scopes) > 0 @@ -95,14 +116,14 @@ def _scope(self) -> list[BranchStatement]: def enter_scope(self) -> None: assert len(self._scope()) > 0 - self.scopes.append([BranchStatement(self._scope()[-1].defined_by_branch[-1])]) + self.scopes.append([BranchStatement(self._scope()[-1].branches[-1])]) def exit_scope(self) -> None: self.scopes.pop() def start_branch_statement(self) -> None: assert len(self._scope()) > 0 - self._scope().append(BranchStatement(self._scope()[-1].defined_by_branch[-1])) + self._scope().append(BranchStatement(self._scope()[-1].branches[-1])) def next_branch(self) -> None: assert len(self._scope()) > 1 @@ -113,6 +134,11 @@ def end_branch_statement(self) -> None: result = self._scope().pop().done() self._scope()[-1].record_nested_branch(result) + def skip_branch(self) -> None: + # Only skip branch if we're outside of "root" branch statement. + if len(self._scope()) > 1: + self._scope()[-1].skip_branch() + def record_declaration(self, name: str) -> None: assert len(self.scopes) > 0 assert len(self.scopes[-1]) > 0 @@ -125,7 +151,7 @@ def is_possibly_undefined(self, name: str) -> bool: return self._scope()[-1].is_possibly_undefined(name) -class PartiallyDefinedVariableVisitor(TraverserVisitor): +class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor): """Detect variables that are defined only part of the time. This visitor detects the following case: @@ -137,8 +163,9 @@ class PartiallyDefinedVariableVisitor(TraverserVisitor): handled by the semantic analyzer. """ - def __init__(self, msg: MessageBuilder) -> None: + def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> None: self.msg = msg + self.type_map = type_map self.tracker = DefinedVariableTracker() def process_lvalue(self, lvalue: Lvalue) -> None: @@ -175,6 +202,13 @@ def visit_func(self, o: FuncItem) -> None: self.tracker.record_declaration(arg.variable.name) super().visit_func(o) + def visit_generator_expr(self, o: GeneratorExpr) -> None: + self.tracker.enter_scope() + for idx in o.indices: + self.process_lvalue(idx) + super().visit_generator_expr(o) + self.tracker.exit_scope() + def visit_for_stmt(self, o: ForStmt) -> None: o.expr.accept(self) self.process_lvalue(o.index) @@ -186,13 +220,40 @@ def visit_for_stmt(self, o: ForStmt) -> None: o.else_body.accept(self) self.tracker.end_branch_statement() + def visit_return_stmt(self, o: ReturnStmt) -> None: + super().visit_return_stmt(o) + self.tracker.skip_branch() + + def visit_assert_stmt(self, o: AssertStmt) -> None: + super().visit_assert_stmt(o) + if checker.is_false_literal(o.expr): + self.tracker.skip_branch() + + def visit_raise_stmt(self, o: RaiseStmt) -> None: + super().visit_raise_stmt(o) + self.tracker.skip_branch() + + def visit_continue_stmt(self, o: ContinueStmt) -> None: + super().visit_continue_stmt(o) + self.tracker.skip_branch() + + def visit_break_stmt(self, o: BreakStmt) -> None: + super().visit_break_stmt(o) + self.tracker.skip_branch() + + def visit_expression_stmt(self, o: ExpressionStmt) -> None: + if isinstance(self.type_map.get(o.expr, None), UninhabitedType): + self.tracker.skip_branch() + super().visit_expression_stmt(o) + def visit_while_stmt(self, o: WhileStmt) -> None: o.expr.accept(self) self.tracker.start_branch_statement() o.body.accept(self) - self.tracker.next_branch() - if o.else_body: - o.else_body.accept(self) + if not checker.is_true_literal(o.expr): + self.tracker.next_branch() + if o.else_body: + o.else_body.accept(self) self.tracker.end_branch_statement() def visit_name_expr(self, o: NameExpr) -> None: diff --git a/mypy/server/update.py b/mypy/server/update.py index c9a8f7f0f0ee..65ce31da7c7a 100644 --- a/mypy/server/update.py +++ b/mypy/server/update.py @@ -651,7 +651,7 @@ def restore(ids: list[str]) -> None: state.type_checker().reset() state.type_check_first_pass() state.type_check_second_pass() - state.detect_partially_defined_vars() + state.detect_partially_defined_vars(state.type_map()) t2 = time.time() state.finish_passes() t3 = time.time() diff --git a/test-data/unit/check-partially-defined.test b/test-data/unit/check-partially-defined.test index a98bc7727575..c77be8148e8f 100644 --- a/test-data/unit/check-partially-defined.test +++ b/test-data/unit/check-partially-defined.test @@ -95,6 +95,13 @@ else: x = y + 2 +[case testGenerator] +# flags: --enable-error-code partially-defined +if int(): + a = 3 +s = [a + 1 for a in [1, 2, 3]] +x = a # E: Name "a" may be undefined + [case testScope] # flags: --enable-error-code partially-defined def foo() -> None: @@ -126,6 +133,12 @@ else: y = z # No error. +while True: + k = 1 + if int(): + break +y = k # No error. + [case testForLoop] # flags: --enable-error-code partially-defined for x in [1, 2, 3]: @@ -137,3 +150,172 @@ else: z = 2 a = z + y # E: Name "y" may be undefined + +[case testReturn] +# flags: --enable-error-code partially-defined +def f1() -> int: + if int(): + x = 1 + else: + return 0 + return x + +def f2() -> int: + if int(): + x = 1 + elif int(): + return 0 + else: + x = 2 + return x + +def f3() -> int: + if int(): + x = 1 + elif int(): + return 0 + else: + y = 2 + return x # E: Name "x" may be undefined + +def f4() -> int: + if int(): + x = 1 + elif int(): + return 0 + else: + return 0 + return x + +def f5() -> int: + # This is a test against crashes. + if int(): + return 1 + if int(): + return 2 + else: + return 3 + return 1 + +[case testAssert] +# flags: --enable-error-code partially-defined +def f1() -> int: + if int(): + x = 1 + else: + assert False, "something something" + return x + +def f2() -> int: + if int(): + x = 1 + elif int(): + assert False + else: + y = 2 + return x # E: Name "x" may be undefined + +[case testRaise] +# flags: --enable-error-code partially-defined +def f1() -> int: + if int(): + x = 1 + else: + raise BaseException("something something") + return x + +def f2() -> int: + if int(): + x = 1 + elif int(): + raise BaseException("something something") + else: + y = 2 + return x # E: Name "x" may be undefined +[builtins fixtures/exception.pyi] + +[case testContinue] +# flags: --enable-error-code partially-defined +def f1() -> int: + while int(): + if int(): + x = 1 + else: + continue + y = x + else: + x = 2 + return x + +def f2() -> int: + while int(): + if int(): + x = 1 + elif int(): + pass + else: + continue + y = x # E: Name "x" may be undefined + else: + x = 2 + return x # E: Name "x" may be undefined + +def f3() -> None: + while True: + if int(): + x = 2 + elif int(): + continue + else: + continue + y = x + +[case testBreak] +# flags: --enable-error-code partially-defined +def f1() -> None: + while int(): + if int(): + x = 1 + else: + break + y = x # No error -- x is always defined. + +def f2() -> None: + while int(): + if int(): + x = 1 + elif int(): + pass + else: + break + y = x # E: Name "x" may be undefined + +def f3() -> None: + while int(): + x = 1 + while int(): + if int(): + x = 2 + else: + break + y = x + z = x # E: Name "x" may be undefined + +[case testNoReturn] +# flags: --enable-error-code partially-defined + +from typing import NoReturn +def fail() -> NoReturn: + assert False + +def f() -> None: + if int(): + x = 1 + elif int(): + x = 2 + y = 3 + else: + # This has a NoReturn type, so we can skip it. + fail() + z = y # E: Name "y" may be undefined + z = x