Skip to content

Commit 84e6c9b

Browse files
committed
Add support for conditionally defined overloads
1 parent 379622d commit 84e6c9b

File tree

2 files changed

+209
-1
lines changed

2 files changed

+209
-1
lines changed

mypy/fastparse.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from mypy import message_registry, errorcodes as codes
3838
from mypy.errors import Errors
3939
from mypy.options import Options
40-
from mypy.reachability import mark_block_unreachable
40+
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
4141

4242
try:
4343
# pull this into a final variable to make mypyc be quiet about the
@@ -444,12 +444,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
444444
ret = [] # type: List[Statement]
445445
current_overload = [] # type: List[OverloadPart]
446446
current_overload_name = None # type: Optional[str]
447+
last_if_stmt: Optional[IfStmt] = None
448+
last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None
447449
for stmt in stmts:
448450
if (current_overload_name is not None
449451
and isinstance(stmt, (Decorator, FuncDef))
450452
and stmt.name == current_overload_name):
453+
if last_if_overload is not None:
454+
if isinstance(last_if_overload, OverloadedFuncDef):
455+
current_overload.extend(last_if_overload.items)
456+
else:
457+
current_overload.append(last_if_overload)
458+
last_if_stmt, last_if_overload = None, None
451459
current_overload.append(stmt)
460+
elif (
461+
current_overload_name is not None
462+
and isinstance(stmt, IfStmt)
463+
and len(stmt.body[0].body) == 1
464+
and isinstance(
465+
stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
466+
and stmt.body[0].body[0].name == current_overload_name
467+
):
468+
# IfStmt only contains stmts relevant to current_overload.
469+
# Check if stmts are reachable and add them to current_overload,
470+
# otherwise skip IfStmt to allow subsequent overload
471+
# or function definitions.
472+
infer_reachability_of_if_statement(stmt, self.options)
473+
if stmt.body[0].is_unreachable is True:
474+
continue
475+
if last_if_overload is not None:
476+
if isinstance(last_if_overload, OverloadedFuncDef):
477+
current_overload.extend(last_if_overload.items)
478+
else:
479+
current_overload.append(last_if_overload)
480+
last_if_stmt, last_if_overload = None, None
481+
last_if_overload = None
482+
if isinstance(stmt.body[0].body[0], OverloadedFuncDef):
483+
current_overload.extend(stmt.body[0].body[0].items)
484+
else:
485+
current_overload.append(stmt.body[0].body[0])
452486
else:
487+
if last_if_stmt is not None:
488+
ret.append(last_if_stmt)
489+
last_if_stmt, last_if_overload = None, None
490+
453491
if len(current_overload) == 1:
454492
ret.append(current_overload[0])
455493
elif len(current_overload) > 1:
@@ -458,6 +496,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
458496
if isinstance(stmt, Decorator):
459497
current_overload = [stmt]
460498
current_overload_name = stmt.name
499+
elif (
500+
isinstance(stmt, IfStmt)
501+
and len(stmt.body[0].body) == 1
502+
and isinstance(
503+
stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
504+
and infer_reachability_of_if_statement(
505+
stmt, self.options
506+
) is None # type: ignore[func-returns-value]
507+
and stmt.body[0].is_unreachable is False
508+
):
509+
current_overload_name = stmt.body[0].body[0].name
510+
last_if_stmt = stmt
511+
last_if_overload = stmt.body[0].body[0]
461512
else:
462513
current_overload = []
463514
current_overload_name = None

test-data/unit/check-overloading.test

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5173,3 +5173,160 @@ def f2(g: G[A, Any]) -> A: ... # E: Overloaded function signatures 1 and 2 over
51735173
@overload
51745174
def f2(g: G[A, B], x: int = ...) -> B: ...
51755175
def f2(g: Any, x: int = ...) -> Any: ...
5176+
5177+
[case testOverloadIfBasic]
5178+
# flags: --always-true True
5179+
from typing import overload, Any
5180+
5181+
class A: ...
5182+
class B: ...
5183+
5184+
@overload
5185+
def f1(g: int) -> A: ...
5186+
if True:
5187+
@overload
5188+
def f1(g: str) -> B: ...
5189+
def f1(g: Any) -> Any: ...
5190+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5191+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5192+
5193+
@overload
5194+
def f2(g: int) -> A: ...
5195+
@overload
5196+
def f2(g: bytes) -> A: ...
5197+
if not True:
5198+
@overload
5199+
def f2(g: str) -> B: ...
5200+
def f2(g: Any) -> Any: ...
5201+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5202+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5203+
# N: Possible overload variants: \
5204+
# N: def f2(g: int) -> A \
5205+
# N: def f2(g: bytes) -> A \
5206+
# N: Revealed type is "Any"
5207+
5208+
[case testOverloadIfSysVersion]
5209+
# flags: --python-version 3.9
5210+
from typing import overload, Any
5211+
import sys
5212+
5213+
class A: ...
5214+
class B: ...
5215+
5216+
@overload
5217+
def f1(g: int) -> A: ...
5218+
if sys.version_info >= (3, 9):
5219+
@overload
5220+
def f1(g: str) -> B: ...
5221+
def f1(g: Any) -> Any: ...
5222+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5223+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5224+
5225+
@overload
5226+
def f2(g: int) -> A: ...
5227+
@overload
5228+
def f2(g: bytes) -> A: ...
5229+
if sys.version_info >= (3, 10):
5230+
@overload
5231+
def f2(g: str) -> B: ...
5232+
def f2(g: Any) -> Any: ...
5233+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5234+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5235+
# N: Possible overload variants: \
5236+
# N: def f2(g: int) -> A \
5237+
# N: def f2(g: bytes) -> A \
5238+
# N: Revealed type is "Any"
5239+
[builtins fixtures/tuple.pyi]
5240+
5241+
[case testOverloadIfMatching]
5242+
from typing import overload, Any
5243+
5244+
class A: ...
5245+
class B: ...
5246+
class C: ...
5247+
5248+
@overload
5249+
def f1(g: int) -> A: ...
5250+
if True:
5251+
# Some comment
5252+
@overload
5253+
def f1(g: str) -> B: ...
5254+
def f1(g: Any) -> Any: ...
5255+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5256+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5257+
5258+
@overload
5259+
def f2(g: int) -> A: ...
5260+
if True:
5261+
@overload
5262+
def f2(g: bytes) -> B: ...
5263+
@overload
5264+
def f2(g: str) -> C: ...
5265+
def f2(g: Any) -> Any: ...
5266+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5267+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.C"
5268+
5269+
@overload
5270+
def f3(g: int) -> A: ...
5271+
@overload
5272+
def f3(g: str) -> B: ...
5273+
if True:
5274+
def f3(g: Any) -> Any: ...
5275+
reveal_type(f3(42)) # N: Revealed type is "__main__.A"
5276+
reveal_type(f3("Hello")) # N: Revealed type is "__main__.B"
5277+
5278+
if True:
5279+
@overload
5280+
def f4(g: int) -> A: ...
5281+
@overload
5282+
def f4(g: str) -> B: ...
5283+
def f4(g: Any) -> Any: ...
5284+
reveal_type(f4(42)) # N: Revealed type is "__main__.A"
5285+
reveal_type(f4("Hello")) # N: Revealed type is "__main__.B"
5286+
5287+
if True:
5288+
# Some comment
5289+
@overload
5290+
def f5(g: int) -> A: ...
5291+
@overload
5292+
def f5(g: str) -> B: ...
5293+
def f5(g: Any) -> Any: ...
5294+
reveal_type(f5(42)) # N: Revealed type is "__main__.A"
5295+
reveal_type(f5("Hello")) # N: Revealed type is "__main__.B"
5296+
5297+
[case testOverloadIfNotMatching]
5298+
from typing import overload, Any
5299+
5300+
class A: ...
5301+
class B: ...
5302+
class C: ...
5303+
5304+
@overload # E: An overloaded function outside a stub file must have an implementation
5305+
def f1(g: int) -> A: ...
5306+
@overload
5307+
def f1(g: bytes) -> B: ...
5308+
if True:
5309+
@overload # E: Name "f1" already defined on line 7 \
5310+
# E: Single overload definition, multiple required
5311+
def f1(g: str) -> C: ...
5312+
pass # Some other action
5313+
def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7
5314+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5315+
reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \
5316+
# N: Possible overload variants: \
5317+
# N: def f1(g: int) -> A \
5318+
# N: def f1(g: bytes) -> B \
5319+
# N: Revealed type is "Any"
5320+
5321+
if True:
5322+
pass # Some other action
5323+
@overload # E: Single overload definition, multiple required
5324+
def f2(g: int) -> A: ...
5325+
@overload # E: Name "f2" already defined on line 21
5326+
def f2(g: bytes) -> B: ...
5327+
@overload
5328+
def f2(g: str) -> C: ...
5329+
def f2(g: Any) -> Any: ...
5330+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5331+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \
5332+
# E: Argument 1 to "f2" has incompatible type "str"; expected "int"

0 commit comments

Comments
 (0)