diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 5f5253515b61..9b3e105f64ef 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -31,6 +31,7 @@ RefExpr, ReturnStmt, StarExpr, + TryStmt, TupleExpr, WhileStmt, WithStmt, @@ -66,6 +67,13 @@ def __init__( self.must_be_defined = set(must_be_defined) self.skipped = skipped + def copy(self) -> BranchState: + return BranchState( + must_be_defined=set(self.must_be_defined), + may_be_defined=set(self.may_be_defined), + skipped=self.skipped, + ) + class BranchStatement: def __init__(self, initial_state: BranchState) -> None: @@ -77,6 +85,11 @@ def __init__(self, initial_state: BranchState) -> None: ) ] + def copy(self) -> BranchStatement: + result = BranchStatement(self.initial_state) + result.branches = [b.copy() for b in self.branches] + return result + def next_branch(self) -> None: self.branches.append( BranchState( @@ -90,6 +103,11 @@ def record_definition(self, name: str) -> None: self.branches[-1].must_be_defined.add(name) self.branches[-1].may_be_defined.discard(name) + def delete_var(self, name: str) -> None: + assert len(self.branches) > 0 + self.branches[-1].must_be_defined.discard(name) + self.branches[-1].may_be_defined.discard(name) + def record_nested_branch(self, state: BranchState) -> None: assert len(self.branches) > 0 current_branch = self.branches[-1] @@ -151,6 +169,11 @@ def __init__(self, stmts: list[BranchStatement]) -> None: self.branch_stmts: list[BranchStatement] = stmts self.undefined_refs: dict[str, set[NameExpr]] = {} + def copy(self) -> Scope: + result = Scope([s.copy() for s in self.branch_stmts]) + result.undefined_refs = self.undefined_refs.copy() + return result + def record_undefined_ref(self, o: NameExpr) -> None: if o.name not in self.undefined_refs: self.undefined_refs[o.name] = set() @@ -166,6 +189,15 @@ class DefinedVariableTracker: def __init__(self) -> None: # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement. self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())])] + # disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful + # in things like try/except/finally statements. + self.disable_branch_skip = False + + def copy(self) -> DefinedVariableTracker: + result = DefinedVariableTracker() + result.scopes = [s.copy() for s in self.scopes] + result.disable_branch_skip = self.disable_branch_skip + return result def _scope(self) -> Scope: assert len(self.scopes) > 0 @@ -195,7 +227,7 @@ def end_branch_statement(self) -> None: def skip_branch(self) -> None: # Only skip branch if we're outside of "root" branch statement. - if len(self._scope().branch_stmts) > 1: + if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip: self._scope().branch_stmts[-1].skip_branch() def record_definition(self, name: str) -> None: @@ -203,6 +235,11 @@ def record_definition(self, name: str) -> None: assert len(self.scopes[-1].branch_stmts) > 0 self._scope().branch_stmts[-1].record_definition(name) + def delete_var(self, name: str) -> None: + assert len(self.scopes) > 0 + assert len(self.scopes[-1].branch_stmts) > 0 + self._scope().branch_stmts[-1].delete_var(name) + def record_undefined_ref(self, o: NameExpr) -> None: """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`.""" assert len(self.scopes) > 0 @@ -268,6 +305,7 @@ def __init__( self.type_map = type_map self.options = options self.loops: list[Loop] = [] + self.try_depth = 0 self.tracker = DefinedVariableTracker() for name in implicit_module_attrs: self.tracker.record_definition(name) @@ -432,6 +470,75 @@ def visit_expression_stmt(self, o: ExpressionStmt) -> None: self.tracker.skip_branch() super().visit_expression_stmt(o) + def visit_try_stmt(self, o: TryStmt) -> None: + """ + Note that finding undefined vars in `finally` requires different handling from + the rest of the code. In particular, we want to disallow skipping branches due to jump + statements in except/else clauses for finally but not for other cases. Imagine a case like: + def f() -> int: + try: + x = 1 + except: + # This jump statement needs to be handled differently depending on whether or + # not we're trying to process `finally` or not. + return 0 + finally: + # `x` may be undefined here. + pass + # `x` is always defined here. + return x + """ + self.try_depth += 1 + if o.finally_body is not None: + # In order to find undefined vars in `finally`, we need to + # process try/except with branch skipping disabled. However, for the rest of the code + # after finally, we need to process try/except with branch skipping enabled. + # Therefore, we need to process try/finally twice. + # Because processing is not idempotent, we should make a copy of the tracker. + old_tracker = self.tracker.copy() + self.tracker.disable_branch_skip = True + self.process_try_stmt(o) + self.tracker = old_tracker + self.process_try_stmt(o) + self.try_depth -= 1 + + def process_try_stmt(self, o: TryStmt) -> None: + """ + Processes try statement decomposing it into the following: + if ...: + body + else_body + elif ...: + except 1 + elif ...: + except 2 + else: + except n + finally + """ + self.tracker.start_branch_statement() + o.body.accept(self) + if o.else_body is not None: + o.else_body.accept(self) + if len(o.handlers) > 0: + assert len(o.handlers) == len(o.vars) == len(o.types) + for i in range(len(o.handlers)): + self.tracker.next_branch() + exc_type = o.types[i] + if exc_type is not None: + exc_type.accept(self) + var = o.vars[i] + if var is not None: + self.process_definition(var.name) + var.accept(self) + o.handlers[i].accept(self) + if var is not None: + self.tracker.delete_var(var.name) + self.tracker.end_branch_statement() + + if o.finally_body is not None: + o.finally_body.accept(self) + def visit_while_stmt(self, o: WhileStmt) -> None: o.expr.accept(self) self.tracker.start_branch_statement() @@ -478,7 +585,9 @@ def visit_name_expr(self, o: NameExpr) -> None: self.tracker.record_definition(o.name) elif self.tracker.is_defined_in_different_branch(o.name): # A variable is defined in one branch but used in a different branch. - if self.loops: + if self.loops or self.try_depth > 0: + # If we're in a loop or in a try, we can't be sure that this variable + # is undefined. Report it as "may be undefined". self.variable_may_be_undefined(o.name, o) else: self.var_used_before_def(o.name, o) diff --git a/test-data/unit/check-possibly-undefined.test b/test-data/unit/check-possibly-undefined.test index d99943572a38..ee7020252de8 100644 --- a/test-data/unit/check-possibly-undefined.test +++ b/test-data/unit/check-possibly-undefined.test @@ -525,6 +525,190 @@ def f3() -> None: y = x z = x # E: Name "x" may be undefined +[case testTryBasic] +# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def +def f1() -> int: + try: + x = 1 + except: + pass + return x # E: Name "x" may be undefined + +def f2() -> int: + try: + pass + except: + x = 1 + return x # E: Name "x" may be undefined + +def f3() -> int: + try: + x = 1 + except: + y = x # E: Name "x" may be undefined + return x # E: Name "x" may be undefined + +def f4() -> int: + try: + x = 1 + except: + return 0 + return x + +def f5() -> int: + try: + x = 1 + except: + raise + return x + +def f6() -> None: + try: + pass + except BaseException as exc: + x = exc # No error. + exc = BaseException() + # This case is covered by the other check, not by possibly undefined check. + y = exc # E: Trying to read deleted variable "exc" + +def f7() -> int: + try: + if int(): + x = 1 + assert False + except: + pass + return x # E: Name "x" may be undefined +[builtins fixtures/exception.pyi] + +[case testTryMultiExcept] +# flags: --enable-error-code possibly-undefined +def f1() -> int: + try: + x = 1 + except BaseException: + x = 2 + except: + x = 3 + return x + +def f2() -> int: + try: + x = 1 + except BaseException: + pass + except: + x = 3 + return x # E: Name "x" may be undefined +[builtins fixtures/exception.pyi] + +[case testTryFinally] +# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def +def f1() -> int: + try: + x = 1 + finally: + x = 2 + return x + +def f2() -> int: + try: + pass + except: + pass + finally: + x = 2 + return x + +def f3() -> int: + try: + x = 1 + except: + pass + finally: + y = x # E: Name "x" may be undefined + return x + +def f4() -> int: + try: + x = 0 + except BaseException: + raise + finally: + y = x # E: Name "x" may be undefined + return y + +def f5() -> int: + try: + if int(): + x = 1 + else: + return 0 + finally: + pass + return x # No error. + +def f6() -> int: + try: + if int(): + x = 1 + else: + return 0 + finally: + a = x # E: Name "x" may be undefined + return a +[builtins fixtures/exception.pyi] + +[case testTryElse] +# flags: --enable-error-code possibly-undefined +def f1() -> int: + try: + return 0 + except BaseException: + x = 1 + else: + x = 2 + finally: + y = x + return y + +def f2() -> int: + try: + pass + except: + x = 1 + else: + x = 2 + return x + +def f3() -> int: + try: + pass + except: + x = 1 + else: + pass + return x # E: Name "x" may be undefined + +def f4() -> int: + try: + x = 1 + except: + x = 2 + else: + pass + return x + +def f5() -> int: + try: + pass + except: + x = 1 + else: + return 1 + return x +[builtins fixtures/exception.pyi] + [case testNoReturn] # flags: --enable-error-code possibly-undefined