diff --git a/mypy/build.py b/mypy/build.py index 62367c35915e..ba54c81845e0 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -2341,7 +2341,9 @@ def type_check_second_pass(self) -> bool: def detect_partially_defined_vars(self, type_map: dict[Expression, Type]) -> None: assert self.tree is not None, "Internal error: method must be called on parsed file only" manager = self.manager - if manager.errors.is_error_code_enabled(codes.PARTIALLY_DEFINED): + if manager.errors.is_error_code_enabled( + codes.PARTIALLY_DEFINED + ) or manager.errors.is_error_code_enabled(codes.USE_BEFORE_DEF): manager.errors.set_file(self.xpath, self.tree.fullname, options=manager.options) self.tree.accept( PartiallyDefinedVariableVisitor( diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index e1efc10b7a8b..1c15407a955b 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -192,6 +192,12 @@ def __str__(self) -> str: "General", default_enabled=False, ) +USE_BEFORE_DEF: Final[ErrorCode] = ErrorCode( + "use-before-def", + "Warn about variables that are used before they are defined", + "General", + default_enabled=False, +) # Syntax errors are often blocking. diff --git a/mypy/messages.py b/mypy/messages.py index 2f487972d647..85fa30512534 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1231,6 +1231,9 @@ def undefined_in_superclass(self, member: str, context: Context) -> None: def variable_may_be_undefined(self, name: str, context: Context) -> None: self.fail(f'Name "{name}" may be undefined', context, code=codes.PARTIALLY_DEFINED) + def var_used_before_def(self, name: str, context: Context) -> None: + self.fail(f'Name "{name}" is used before definition', context, code=codes.USE_BEFORE_DEF) + def first_argument_for_super_must_be_type(self, actual: Type, context: Context) -> None: actual = get_proper_type(actual) if isinstance(actual, Instance): diff --git a/mypy/partially_defined.py b/mypy/partially_defined.py index 5854036c0df3..70a454beae9c 100644 --- a/mypy/partially_defined.py +++ b/mypy/partially_defined.py @@ -1,6 +1,6 @@ from __future__ import annotations -from mypy import checker +from mypy import checker, errorcodes from mypy.messages import MessageBuilder from mypy.nodes import ( AssertStmt, @@ -93,10 +93,24 @@ def skip_branch(self) -> None: assert len(self.branches) > 0 self.branches[-1].skipped = True - def is_possibly_undefined(self, name: str) -> bool: + def is_partially_defined(self, name: str) -> bool: assert len(self.branches) > 0 return name in self.branches[-1].may_be_defined + def is_undefined(self, name: str) -> bool: + assert len(self.branches) > 0 + branch = self.branches[-1] + return name not in branch.may_be_defined and name not in branch.must_be_defined + + def is_defined_in_different_branch(self, name: str) -> bool: + assert len(self.branches) > 0 + if not self.is_undefined(name): + return False + for b in self.branches[: len(self.branches) - 1]: + if name in b.must_be_defined or name in b.may_be_defined: + return True + return False + def done(self) -> BranchState: branches = [b for b in self.branches if not b.skipped] if len(branches) == 0: @@ -117,62 +131,102 @@ def done(self) -> BranchState: return BranchState(may_be_defined=may_be_defined, must_be_defined=must_be_defined) +class Scope: + def __init__(self, stmts: list[BranchStatement]) -> None: + self.branch_stmts: list[BranchStatement] = stmts + self.undefined_refs: dict[str, set[NameExpr]] = {} + + def record_undefined_ref(self, o: NameExpr) -> None: + if o.name not in self.undefined_refs: + self.undefined_refs[o.name] = set() + self.undefined_refs[o.name].add(o) + + def pop_undefined_ref(self, name: str) -> set[NameExpr]: + return self.undefined_refs.pop(name, set()) + + class DefinedVariableTracker: """DefinedVariableTracker manages the state and scope for the UndefinedVariablesVisitor.""" def __init__(self) -> None: # There's always at least one scope. Within each scope, there's at least one "global" BranchingStatement. - self.scopes: list[list[BranchStatement]] = [[BranchStatement(BranchState())]] + self.scopes: list[Scope] = [Scope([BranchStatement(BranchState())])] - def _scope(self) -> list[BranchStatement]: + def _scope(self) -> Scope: assert len(self.scopes) > 0 return self.scopes[-1] def enter_scope(self) -> None: - assert len(self._scope()) > 0 - self.scopes.append([BranchStatement(self._scope()[-1].branches[-1])]) + assert len(self._scope().branch_stmts) > 0 + self.scopes.append(Scope([BranchStatement(self._scope().branch_stmts[-1].branches[-1])])) def exit_scope(self) -> None: self.scopes.pop() def start_branch_statement(self) -> None: - assert len(self._scope()) > 0 - self._scope().append(BranchStatement(self._scope()[-1].branches[-1])) + assert len(self._scope().branch_stmts) > 0 + self._scope().branch_stmts.append( + BranchStatement(self._scope().branch_stmts[-1].branches[-1]) + ) def next_branch(self) -> None: - assert len(self._scope()) > 1 - self._scope()[-1].next_branch() + assert len(self._scope().branch_stmts) > 1 + self._scope().branch_stmts[-1].next_branch() def end_branch_statement(self) -> None: - assert len(self._scope()) > 1 - result = self._scope().pop().done() - self._scope()[-1].record_nested_branch(result) + assert len(self._scope().branch_stmts) > 1 + result = self._scope().branch_stmts.pop().done() + self._scope().branch_stmts[-1].record_nested_branch(result) def skip_branch(self) -> None: # Only skip branch if we're outside of "root" branch statement. - if len(self._scope()) > 1: - self._scope()[-1].skip_branch() + if len(self._scope().branch_stmts) > 1: + self._scope().branch_stmts[-1].skip_branch() - def record_declaration(self, name: str) -> None: + def record_definition(self, name: str) -> None: + assert len(self.scopes) > 0 + assert len(self.scopes[-1].branch_stmts) > 0 + self._scope().branch_stmts[-1].record_definition(name) + + def record_undefined_ref(self, o: NameExpr) -> None: + """Records an undefined reference. These can later be retrieved via `pop_undefined_ref`.""" + assert len(self.scopes) > 0 + self._scope().record_undefined_ref(o) + + def pop_undefined_ref(self, name: str) -> set[NameExpr]: + """If name has previously been reported as undefined, the NameExpr that was called will be returned.""" assert len(self.scopes) > 0 - assert len(self.scopes[-1]) > 0 - self._scope()[-1].record_definition(name) + return self._scope().pop_undefined_ref(name) - def is_possibly_undefined(self, name: str) -> bool: - assert len(self._scope()) > 0 + def is_partially_defined(self, name: str) -> bool: + assert len(self._scope().branch_stmts) > 0 # A variable is undefined if it's in a set of `may_be_defined` but not in `must_be_defined`. - # Cases where a variable is not defined altogether are handled by semantic analyzer. - return self._scope()[-1].is_possibly_undefined(name) + return self._scope().branch_stmts[-1].is_partially_defined(name) + + def is_defined_in_different_branch(self, name: str) -> bool: + """This will return true if a variable is defined in a branch that's not the current branch.""" + assert len(self._scope().branch_stmts) > 0 + return self._scope().branch_stmts[-1].is_defined_in_different_branch(name) + + def is_undefined(self, name: str) -> bool: + assert len(self._scope().branch_stmts) > 0 + return self._scope().branch_stmts[-1].is_undefined(name) class PartiallyDefinedVariableVisitor(ExtendedTraverserVisitor): - """Detect variables that are defined only part of the time. + """Detects the following cases: + - A variable that's defined only part of the time. + - If a variable is used before definition - This visitor detects the following case: + An example of a partial definition: if foo(): x = 1 print(x) # Error: "x" may be undefined. + Example of a use before definition: + x = y + y: int = 2 + Note that this code does not detect variables not defined in any of the branches -- that is handled by the semantic analyzer. """ @@ -184,7 +238,11 @@ def __init__(self, msg: MessageBuilder, type_map: dict[Expression, Type]) -> Non def process_lvalue(self, lvalue: Lvalue | None) -> None: if isinstance(lvalue, NameExpr): - self.tracker.record_declaration(lvalue.name) + # Was this name previously used? If yes, it's a use-before-definition error. + refs = self.tracker.pop_undefined_ref(lvalue.name) + for ref in refs: + self.msg.var_used_before_def(lvalue.name, ref) + self.tracker.record_definition(lvalue.name) elif isinstance(lvalue, (ListExpr, TupleExpr)): for item in lvalue.items: self.process_lvalue(item) @@ -239,7 +297,7 @@ def visit_func_def(self, o: FuncDef) -> None: def visit_func(self, o: FuncItem) -> None: if o.arguments is not None: for arg in o.arguments: - self.tracker.record_declaration(arg.variable.name) + self.tracker.record_definition(arg.variable.name) super().visit_func(o) def visit_generator_expr(self, o: GeneratorExpr) -> None: @@ -314,10 +372,23 @@ def visit_starred_pattern(self, o: StarredPattern) -> None: super().visit_starred_pattern(o) def visit_name_expr(self, o: NameExpr) -> None: - if self.tracker.is_possibly_undefined(o.name): - self.msg.variable_may_be_undefined(o.name, o) + if self.tracker.is_partially_defined(o.name): + # A variable is only defined in some branches. + if self.msg.errors.is_error_code_enabled(errorcodes.PARTIALLY_DEFINED): + self.msg.variable_may_be_undefined(o.name, o) # We don't want to report the error on the same variable multiple times. - self.tracker.record_declaration(o.name) + self.tracker.record_definition(o.name) + elif self.tracker.is_defined_in_different_branch(o.name): + # A variable is defined in one branch but used in a different branch. + self.msg.var_used_before_def(o.name, o) + elif self.tracker.is_undefined(o.name): + # A variable is undefined. It could be due to two things: + # 1. A variable is just totally undefined + # 2. The variable is defined later in the code. + # Case (1) will be caught by semantic analyzer. Case (2) is a forward ref that should + # be caught by this visitor. Save the ref for later, so that if we see a definition, + # we know it's a use-before-definition scenario. + self.tracker.record_undefined_ref(o) super().visit_name_expr(o) def visit_with_stmt(self, o: WithStmt) -> None: diff --git a/test-data/unit/check-partially-defined.test b/test-data/unit/check-partially-defined.test index f6934fb142d1..c63023aa2746 100644 --- a/test-data/unit/check-partially-defined.test +++ b/test-data/unit/check-partially-defined.test @@ -386,3 +386,49 @@ else: z = 1 a = z [typing fixtures/typing-medium.pyi] + +[case testUseBeforeDef] +# flags: --enable-error-code use-before-def + +def f0() -> None: + x = y # E: Name "y" is used before definition + y: int = 1 + +def f1() -> None: + if int(): + x = 0 + else: + y = x # E: Name "x" is used before definition + z = x # E: Name "x" is used before definition + +def f2() -> None: + x = 1 + if int(): + x = 0 + else: + y = x # No error. + +def f3() -> None: + if int(): + pass + else: + # No use-before-def error. + y = z # E: Name "z" is not defined + + def inner2() -> None: + z = 0 + +def f4() -> None: + if int(): + pass + else: + y = z # E: Name "z" is used before definition + z: int = 2 + +def f5() -> None: + if int(): + pass + else: + y = z # E: Name "z" is used before definition + x = z # E: Name "z" is used before definition + z: int = 2