Skip to content

Commit 399976b

Browse files
committed
Add support for conditionally defined overloads
1 parent b3b3242 commit 399976b

File tree

2 files changed

+209
-1
lines changed

2 files changed

+209
-1
lines changed

mypy/fastparse.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from mypy import message_registry, errorcodes as codes
3939
from mypy.errors import Errors
4040
from mypy.options import Options
41-
from mypy.reachability import mark_block_unreachable
41+
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
4242

4343
try:
4444
# pull this into a final variable to make mypyc be quiet about the
@@ -445,12 +445,50 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
445445
ret: List[Statement] = []
446446
current_overload: List[OverloadPart] = []
447447
current_overload_name: Optional[str] = None
448+
last_if_stmt: Optional[IfStmt] = None
449+
last_if_overload: Optional[Union[Decorator, OverloadedFuncDef]] = None
448450
for stmt in stmts:
449451
if (current_overload_name is not None
450452
and isinstance(stmt, (Decorator, FuncDef))
451453
and stmt.name == current_overload_name):
454+
if last_if_overload is not None:
455+
if isinstance(last_if_overload, OverloadedFuncDef):
456+
current_overload.extend(last_if_overload.items)
457+
else:
458+
current_overload.append(last_if_overload)
459+
last_if_stmt, last_if_overload = None, None
452460
current_overload.append(stmt)
461+
elif (
462+
current_overload_name is not None
463+
and isinstance(stmt, IfStmt)
464+
and len(stmt.body[0].body) == 1
465+
and isinstance(
466+
stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
467+
and stmt.body[0].body[0].name == current_overload_name
468+
):
469+
# IfStmt only contains stmts relevant to current_overload.
470+
# Check if stmts are reachable and add them to current_overload,
471+
# otherwise skip IfStmt to allow subsequent overload
472+
# or function definitions.
473+
infer_reachability_of_if_statement(stmt, self.options)
474+
if stmt.body[0].is_unreachable is True:
475+
continue
476+
if last_if_overload is not None:
477+
if isinstance(last_if_overload, OverloadedFuncDef):
478+
current_overload.extend(last_if_overload.items)
479+
else:
480+
current_overload.append(last_if_overload)
481+
last_if_stmt, last_if_overload = None, None
482+
last_if_overload = None
483+
if isinstance(stmt.body[0].body[0], OverloadedFuncDef):
484+
current_overload.extend(stmt.body[0].body[0].items)
485+
else:
486+
current_overload.append(stmt.body[0].body[0])
453487
else:
488+
if last_if_stmt is not None:
489+
ret.append(last_if_stmt)
490+
last_if_stmt, last_if_overload = None, None
491+
454492
if len(current_overload) == 1:
455493
ret.append(current_overload[0])
456494
elif len(current_overload) > 1:
@@ -464,6 +502,19 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
464502
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
465503
current_overload = [stmt]
466504
current_overload_name = stmt.name
505+
elif (
506+
isinstance(stmt, IfStmt)
507+
and len(stmt.body[0].body) == 1
508+
and isinstance(
509+
stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
510+
and infer_reachability_of_if_statement(
511+
stmt, self.options
512+
) is None # type: ignore[func-returns-value]
513+
and stmt.body[0].is_unreachable is False
514+
):
515+
current_overload_name = stmt.body[0].body[0].name
516+
last_if_stmt = stmt
517+
last_if_overload = stmt.body[0].body[0]
467518
else:
468519
current_overload = []
469520
current_overload_name = None

test-data/unit/check-overloading.test

+157
Original file line numberDiff line numberDiff line change
@@ -5337,3 +5337,160 @@ def register(cls: Any) -> Any: return None
53375337
x = register(Foo)
53385338
reveal_type(x) # N: Revealed type is "builtins.int"
53395339
[builtins fixtures/dict.pyi]
5340+
5341+
[case testOverloadIfBasic]
5342+
# flags: --always-true True
5343+
from typing import overload, Any
5344+
5345+
class A: ...
5346+
class B: ...
5347+
5348+
@overload
5349+
def f1(g: int) -> A: ...
5350+
if True:
5351+
@overload
5352+
def f1(g: str) -> B: ...
5353+
def f1(g: Any) -> Any: ...
5354+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5355+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5356+
5357+
@overload
5358+
def f2(g: int) -> A: ...
5359+
@overload
5360+
def f2(g: bytes) -> A: ...
5361+
if not True:
5362+
@overload
5363+
def f2(g: str) -> B: ...
5364+
def f2(g: Any) -> Any: ...
5365+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5366+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5367+
# N: Possible overload variants: \
5368+
# N: def f2(g: int) -> A \
5369+
# N: def f2(g: bytes) -> A \
5370+
# N: Revealed type is "Any"
5371+
5372+
[case testOverloadIfSysVersion]
5373+
# flags: --python-version 3.9
5374+
from typing import overload, Any
5375+
import sys
5376+
5377+
class A: ...
5378+
class B: ...
5379+
5380+
@overload
5381+
def f1(g: int) -> A: ...
5382+
if sys.version_info >= (3, 9):
5383+
@overload
5384+
def f1(g: str) -> B: ...
5385+
def f1(g: Any) -> Any: ...
5386+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5387+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5388+
5389+
@overload
5390+
def f2(g: int) -> A: ...
5391+
@overload
5392+
def f2(g: bytes) -> A: ...
5393+
if sys.version_info >= (3, 10):
5394+
@overload
5395+
def f2(g: str) -> B: ...
5396+
def f2(g: Any) -> Any: ...
5397+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5398+
reveal_type(f2("Hello")) # E: No overload variant of "f2" matches argument type "str" \
5399+
# N: Possible overload variants: \
5400+
# N: def f2(g: int) -> A \
5401+
# N: def f2(g: bytes) -> A \
5402+
# N: Revealed type is "Any"
5403+
[builtins fixtures/tuple.pyi]
5404+
5405+
[case testOverloadIfMatching]
5406+
from typing import overload, Any
5407+
5408+
class A: ...
5409+
class B: ...
5410+
class C: ...
5411+
5412+
@overload
5413+
def f1(g: int) -> A: ...
5414+
if True:
5415+
# Some comment
5416+
@overload
5417+
def f1(g: str) -> B: ...
5418+
def f1(g: Any) -> Any: ...
5419+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5420+
reveal_type(f1("Hello")) # N: Revealed type is "__main__.B"
5421+
5422+
@overload
5423+
def f2(g: int) -> A: ...
5424+
if True:
5425+
@overload
5426+
def f2(g: bytes) -> B: ...
5427+
@overload
5428+
def f2(g: str) -> C: ...
5429+
def f2(g: Any) -> Any: ...
5430+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5431+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.C"
5432+
5433+
@overload
5434+
def f3(g: int) -> A: ...
5435+
@overload
5436+
def f3(g: str) -> B: ...
5437+
if True:
5438+
def f3(g: Any) -> Any: ...
5439+
reveal_type(f3(42)) # N: Revealed type is "__main__.A"
5440+
reveal_type(f3("Hello")) # N: Revealed type is "__main__.B"
5441+
5442+
if True:
5443+
@overload
5444+
def f4(g: int) -> A: ...
5445+
@overload
5446+
def f4(g: str) -> B: ...
5447+
def f4(g: Any) -> Any: ...
5448+
reveal_type(f4(42)) # N: Revealed type is "__main__.A"
5449+
reveal_type(f4("Hello")) # N: Revealed type is "__main__.B"
5450+
5451+
if True:
5452+
# Some comment
5453+
@overload
5454+
def f5(g: int) -> A: ...
5455+
@overload
5456+
def f5(g: str) -> B: ...
5457+
def f5(g: Any) -> Any: ...
5458+
reveal_type(f5(42)) # N: Revealed type is "__main__.A"
5459+
reveal_type(f5("Hello")) # N: Revealed type is "__main__.B"
5460+
5461+
[case testOverloadIfNotMatching]
5462+
from typing import overload, Any
5463+
5464+
class A: ...
5465+
class B: ...
5466+
class C: ...
5467+
5468+
@overload # E: An overloaded function outside a stub file must have an implementation
5469+
def f1(g: int) -> A: ...
5470+
@overload
5471+
def f1(g: bytes) -> B: ...
5472+
if True:
5473+
@overload # E: Name "f1" already defined on line 7 \
5474+
# E: Single overload definition, multiple required
5475+
def f1(g: str) -> C: ...
5476+
pass # Some other action
5477+
def f1(g: Any) -> Any: ... # E: Name "f1" already defined on line 7
5478+
reveal_type(f1(42)) # N: Revealed type is "__main__.A"
5479+
reveal_type(f1("Hello")) # E: No overload variant of "f1" matches argument type "str" \
5480+
# N: Possible overload variants: \
5481+
# N: def f1(g: int) -> A \
5482+
# N: def f1(g: bytes) -> B \
5483+
# N: Revealed type is "Any"
5484+
5485+
if True:
5486+
pass # Some other action
5487+
@overload # E: Single overload definition, multiple required
5488+
def f2(g: int) -> A: ...
5489+
@overload # E: Name "f2" already defined on line 21
5490+
def f2(g: bytes) -> B: ...
5491+
@overload
5492+
def f2(g: str) -> C: ...
5493+
def f2(g: Any) -> Any: ...
5494+
reveal_type(f2(42)) # N: Revealed type is "__main__.A"
5495+
reveal_type(f2("Hello")) # N: Revealed type is "__main__.A" \
5496+
# E: Argument 1 to "f2" has incompatible type "str"; expected "int"

0 commit comments

Comments
 (0)