Skip to content

[partially defined] implement support for try statements #14114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 111 additions & 2 deletions mypy/partially_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RefExpr,
ReturnStmt,
StarExpr,
TryStmt,
TupleExpr,
WhileStmt,
WithStmt,
Expand Down Expand Up @@ -66,6 +67,13 @@ def __init__(
self.must_be_defined = set(must_be_defined)
self.skipped = skipped

def copy(self) -> BranchState:
return BranchState(
must_be_defined=set(self.must_be_defined),
may_be_defined=set(self.may_be_defined),
skipped=self.skipped,
)


class BranchStatement:
def __init__(self, initial_state: BranchState) -> None:
Expand All @@ -77,6 +85,11 @@ def __init__(self, initial_state: BranchState) -> None:
)
]

def copy(self) -> BranchStatement:
result = BranchStatement(self.initial_state)
result.branches = [b.copy() for b in self.branches]
return result

def next_branch(self) -> None:
self.branches.append(
BranchState(
Expand All @@ -90,6 +103,11 @@ def record_definition(self, name: str) -> None:
self.branches[-1].must_be_defined.add(name)
self.branches[-1].may_be_defined.discard(name)

def delete_var(self, name: str) -> None:
assert len(self.branches) > 0
self.branches[-1].must_be_defined.discard(name)
self.branches[-1].may_be_defined.discard(name)

def record_nested_branch(self, state: BranchState) -> None:
assert len(self.branches) > 0
current_branch = self.branches[-1]
Expand Down Expand Up @@ -151,6 +169,11 @@ def __init__(self, stmts: list[BranchStatement]) -> None:
self.branch_stmts: list[BranchStatement] = stmts
self.undefined_refs: dict[str, set[NameExpr]] = {}

def copy(self) -> Scope:
result = Scope([s.copy() for s in self.branch_stmts])
result.undefined_refs = self.undefined_refs.copy()
return result

def record_undefined_ref(self, o: NameExpr) -> None:
if o.name not in self.undefined_refs:
self.undefined_refs[o.name] = set()
Expand All @@ -166,6 +189,15 @@ class DefinedVariableTracker:
def __init__(self) -> None:
# There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement.
self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())])]
# disable_branch_skip is used to disable skipping a branch due to a return/raise/etc. This is useful
# in things like try/except/finally statements.
self.disable_branch_skip = False

def copy(self) -> DefinedVariableTracker:
result = DefinedVariableTracker()
result.scopes = [s.copy() for s in self.scopes]
result.disable_branch_skip = self.disable_branch_skip
return result

def _scope(self) -> Scope:
assert len(self.scopes) > 0
Expand Down Expand Up @@ -195,14 +227,19 @@ def end_branch_statement(self) -> None:

def skip_branch(self) -> None:
# Only skip branch if we're outside of "root" branch statement.
if len(self._scope().branch_stmts) > 1:
if len(self._scope().branch_stmts) > 1 and not self.disable_branch_skip:
self._scope().branch_stmts[-1].skip_branch()

def record_definition(self, name: str) -> None:
assert len(self.scopes) > 0
assert len(self.scopes[-1].branch_stmts) > 0
self._scope().branch_stmts[-1].record_definition(name)

def delete_var(self, name: str) -> None:
assert len(self.scopes) > 0
assert len(self.scopes[-1].branch_stmts) > 0
self._scope().branch_stmts[-1].delete_var(name)

def record_undefined_ref(self, o: NameExpr) -> None:
"""Records an undefined reference. These can later be retrieved via `pop_undefined_ref`."""
assert len(self.scopes) > 0
Expand Down Expand Up @@ -268,6 +305,7 @@ def __init__(
self.type_map = type_map
self.options = options
self.loops: list[Loop] = []
self.try_depth = 0
self.tracker = DefinedVariableTracker()
for name in implicit_module_attrs:
self.tracker.record_definition(name)
Expand Down Expand Up @@ -432,6 +470,75 @@ def visit_expression_stmt(self, o: ExpressionStmt) -> None:
self.tracker.skip_branch()
super().visit_expression_stmt(o)

def visit_try_stmt(self, o: TryStmt) -> None:
"""
Note that finding undefined vars in `finally` requires different handling from
the rest of the code. In particular, we want to disallow skipping branches due to jump
statements in except/else clauses for finally but not for other cases. Imagine a case like:
def f() -> int:
try:
x = 1
except:
# This jump statement needs to be handled differently depending on whether or
# not we're trying to process `finally` or not.
return 0
finally:
# `x` may be undefined here.
pass
# `x` is always defined here.
return x
"""
self.try_depth += 1
if o.finally_body is not None:
# In order to find undefined vars in `finally`, we need to
# process try/except with branch skipping disabled. However, for the rest of the code
# after finally, we need to process try/except with branch skipping enabled.
# Therefore, we need to process try/finally twice.
# Because processing is not idempotent, we should make a copy of the tracker.
old_tracker = self.tracker.copy()
self.tracker.disable_branch_skip = True
self.process_try_stmt(o)
self.tracker = old_tracker
self.process_try_stmt(o)
self.try_depth -= 1

def process_try_stmt(self, o: TryStmt) -> None:
"""
Processes try statement decomposing it into the following:
if ...:
body
else_body
elif ...:
except 1
elif ...:
except 2
else:
except n
finally
"""
self.tracker.start_branch_statement()
o.body.accept(self)
if o.else_body is not None:
o.else_body.accept(self)
if len(o.handlers) > 0:
assert len(o.handlers) == len(o.vars) == len(o.types)
for i in range(len(o.handlers)):
self.tracker.next_branch()
exc_type = o.types[i]
if exc_type is not None:
exc_type.accept(self)
var = o.vars[i]
if var is not None:
self.process_definition(var.name)
var.accept(self)
o.handlers[i].accept(self)
if var is not None:
self.tracker.delete_var(var.name)
self.tracker.end_branch_statement()

if o.finally_body is not None:
o.finally_body.accept(self)

def visit_while_stmt(self, o: WhileStmt) -> None:
o.expr.accept(self)
self.tracker.start_branch_statement()
Expand Down Expand Up @@ -478,7 +585,9 @@ def visit_name_expr(self, o: NameExpr) -> None:
self.tracker.record_definition(o.name)
elif self.tracker.is_defined_in_different_branch(o.name):
# A variable is defined in one branch but used in a different branch.
if self.loops:
if self.loops or self.try_depth > 0:
# If we're in a loop or in a try, we can't be sure that this variable
# is undefined. Report it as "may be undefined".
self.variable_may_be_undefined(o.name, o)
else:
self.var_used_before_def(o.name, o)
Expand Down
184 changes: 184 additions & 0 deletions test-data/unit/check-possibly-undefined.test
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,190 @@ def f3() -> None:
y = x
z = x # E: Name "x" may be undefined

[case testTryBasic]
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
def f1() -> int:
try:
x = 1
except:
pass
return x # E: Name "x" may be undefined

def f2() -> int:
try:
pass
except:
x = 1
return x # E: Name "x" may be undefined

def f3() -> int:
try:
x = 1
except:
y = x # E: Name "x" may be undefined
return x # E: Name "x" may be undefined

def f4() -> int:
try:
x = 1
except:
return 0
return x

def f5() -> int:
try:
x = 1
except:
raise
return x

def f6() -> None:
try:
pass
except BaseException as exc:
x = exc # No error.
exc = BaseException()
# This case is covered by the other check, not by possibly undefined check.
y = exc # E: Trying to read deleted variable "exc"

def f7() -> int:
try:
if int():
x = 1
assert False
except:
pass
return x # E: Name "x" may be undefined
[builtins fixtures/exception.pyi]

[case testTryMultiExcept]
# flags: --enable-error-code possibly-undefined
def f1() -> int:
try:
x = 1
except BaseException:
x = 2
except:
x = 3
return x

def f2() -> int:
try:
x = 1
except BaseException:
pass
except:
x = 3
return x # E: Name "x" may be undefined
[builtins fixtures/exception.pyi]

[case testTryFinally]
# flags: --enable-error-code possibly-undefined --enable-error-code used-before-def
def f1() -> int:
try:
x = 1
finally:
x = 2
return x

def f2() -> int:
try:
pass
except:
pass
finally:
x = 2
return x

def f3() -> int:
try:
x = 1
except:
pass
finally:
y = x # E: Name "x" may be undefined
return x

def f4() -> int:
try:
x = 0
except BaseException:
raise
finally:
y = x # E: Name "x" may be undefined
return y

def f5() -> int:
try:
if int():
x = 1
else:
return 0
finally:
pass
return x # No error.

def f6() -> int:
try:
if int():
x = 1
else:
return 0
finally:
a = x # E: Name "x" may be undefined
return a
[builtins fixtures/exception.pyi]

[case testTryElse]
# flags: --enable-error-code possibly-undefined
def f1() -> int:
try:
return 0
except BaseException:
x = 1
else:
x = 2
finally:
y = x
return y

def f2() -> int:
try:
pass
except:
x = 1
else:
x = 2
return x

def f3() -> int:
try:
pass
except:
x = 1
else:
pass
return x # E: Name "x" may be undefined

def f4() -> int:
try:
x = 1
except:
x = 2
else:
pass
return x

def f5() -> int:
try:
pass
except:
x = 1
else:
return 1
return x
[builtins fixtures/exception.pyi]

[case testNoReturn]
# flags: --enable-error-code possibly-undefined

Expand Down