Skip to content

Commit cf6a48c

Browse files
authored
Fix nested overload merging (#12607)
Closes #12606
1 parent 9e9de71 commit cf6a48c

File tree

2 files changed

+190
-18
lines changed

2 files changed

+190
-18
lines changed

mypy/fastparse.py

+51-18
Original file line numberDiff line numberDiff line change
@@ -496,18 +496,9 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
496496
if_overload_name: Optional[str] = None
497497
if_block_with_overload: Optional[Block] = None
498498
if_unknown_truth_value: Optional[IfStmt] = None
499-
if (
500-
isinstance(stmt, IfStmt)
501-
and len(stmt.body[0].body) == 1
502-
and seen_unconditional_func_def is False
503-
and (
504-
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
505-
or current_overload_name is not None
506-
and isinstance(stmt.body[0].body[0], FuncDef)
507-
)
508-
):
499+
if isinstance(stmt, IfStmt) and seen_unconditional_func_def is False:
509500
# Check IfStmt block to determine if function overloads can be merged
510-
if_overload_name = self._check_ifstmt_for_overloads(stmt)
501+
if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name)
511502
if if_overload_name is not None:
512503
if_block_with_overload, if_unknown_truth_value = \
513504
self._get_executable_if_block_with_overloads(stmt)
@@ -553,8 +544,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
553544
else:
554545
current_overload.append(last_if_overload)
555546
last_if_stmt, last_if_overload = None, None
556-
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
557-
current_overload.extend(if_block_with_overload.body[0].items)
547+
if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
548+
skipped_if_stmts.extend(
549+
cast(List[IfStmt], if_block_with_overload.body[:-1])
550+
)
551+
current_overload.extend(if_block_with_overload.body[-1].items)
558552
else:
559553
current_overload.append(
560554
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
@@ -600,9 +594,12 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
600594
last_if_stmt = stmt
601595
last_if_stmt_overload_name = None
602596
if if_block_with_overload is not None:
597+
skipped_if_stmts.extend(
598+
cast(List[IfStmt], if_block_with_overload.body[:-1])
599+
)
603600
last_if_overload = cast(
604601
Union[Decorator, FuncDef, OverloadedFuncDef],
605-
if_block_with_overload.body[0]
602+
if_block_with_overload.body[-1]
606603
)
607604
last_if_unknown_truth_value = if_unknown_truth_value
608605
else:
@@ -620,23 +617,38 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
620617
ret.append(current_overload[0])
621618
elif len(current_overload) > 1:
622619
ret.append(OverloadedFuncDef(current_overload))
620+
elif last_if_overload is not None:
621+
ret.append(last_if_overload)
623622
elif last_if_stmt is not None:
624623
ret.append(last_if_stmt)
625624
return ret
626625

627-
def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
626+
def _check_ifstmt_for_overloads(
627+
self, stmt: IfStmt, current_overload_name: Optional[str] = None
628+
) -> Optional[str]:
628629
"""Check if IfStmt contains only overloads with the same name.
629630
Return overload_name if found, None otherwise.
630631
"""
631632
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
632633
# Multiple overloads have already been merged as OverloadedFuncDef.
633634
if not (
634635
len(stmt.body[0].body) == 1
635-
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
636+
and (
637+
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
638+
or current_overload_name is not None
639+
and isinstance(stmt.body[0].body[0], FuncDef)
640+
)
641+
or len(stmt.body[0].body) > 1
642+
and isinstance(stmt.body[0].body[-1], OverloadedFuncDef)
643+
and all(
644+
self._is_stripped_if_stmt(if_stmt)
645+
for if_stmt in stmt.body[0].body[:-1]
646+
)
636647
):
637648
return None
638649

639-
overload_name = stmt.body[0].body[0].name
650+
overload_name = cast(
651+
Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1]).name
640652
if stmt.else_body is None:
641653
return overload_name
642654

@@ -649,7 +661,9 @@ def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
649661
return overload_name
650662
if (
651663
isinstance(stmt.else_body.body[0], IfStmt)
652-
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
664+
and self._check_ifstmt_for_overloads(
665+
stmt.else_body.body[0], current_overload_name
666+
) == overload_name
653667
):
654668
return overload_name
655669

@@ -704,6 +718,25 @@ def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
704718
else:
705719
stmt.else_body.body = []
706720

721+
def _is_stripped_if_stmt(self, stmt: Statement) -> bool:
722+
"""Check stmt to make sure it is a stripped IfStmt.
723+
724+
See also: _strip_contents_from_if_stmt
725+
"""
726+
if not isinstance(stmt, IfStmt):
727+
return False
728+
729+
if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0):
730+
# Body not empty
731+
return False
732+
733+
if not stmt.else_body or len(stmt.else_body.body) == 0:
734+
# No or empty else_body
735+
return True
736+
737+
# For elif, IfStmt are stored recursively in else_body
738+
return self._is_stripped_if_stmt(stmt.else_body.body[0])
739+
707740
def in_method_scope(self) -> bool:
708741
return self.class_and_function_stack[-2:] == ['C', 'F']
709742

test-data/unit/check-overloading.test

+139
Original file line numberDiff line numberDiff line change
@@ -6367,3 +6367,142 @@ def g(x: int) -> str: ...
63676367

63686368
def g(x: int = 0) -> int: # E: Overloaded function implementation cannot produce return type of signature 2
63696369
return x
6370+
6371+
[case testOverloadIfNestedOk]
6372+
# flags: --always-true True --always-false False
6373+
from typing import overload
6374+
6375+
class A: ...
6376+
class B: ...
6377+
class C: ...
6378+
class D: ...
6379+
6380+
@overload
6381+
def f1(g: A) -> A: ...
6382+
if True:
6383+
@overload
6384+
def f1(g: B) -> B: ...
6385+
if True:
6386+
@overload
6387+
def f1(g: C) -> C: ...
6388+
@overload
6389+
def f1(g: D) -> D: ...
6390+
def f1(g): ...
6391+
reveal_type(f1(A())) # N: Revealed type is "__main__.A"
6392+
reveal_type(f1(B())) # N: Revealed type is "__main__.B"
6393+
reveal_type(f1(C())) # N: Revealed type is "__main__.C"
6394+
reveal_type(f1(D())) # N: Revealed type is "__main__.D"
6395+
6396+
@overload
6397+
def f2(g: A) -> A: ...
6398+
if True:
6399+
@overload
6400+
def f2(g: B) -> B: ...
6401+
if True:
6402+
@overload
6403+
def f2(g: C) -> C: ...
6404+
if True:
6405+
@overload
6406+
def f2(g: D) -> D: ...
6407+
def f2(g): ...
6408+
reveal_type(f2(A())) # N: Revealed type is "__main__.A"
6409+
reveal_type(f2(B())) # N: Revealed type is "__main__.B"
6410+
reveal_type(f2(C())) # N: Revealed type is "__main__.C"
6411+
reveal_type(f2(D())) # N: Revealed type is "__main__.D"
6412+
6413+
@overload
6414+
def f3(g: A) -> A: ...
6415+
if True:
6416+
if True:
6417+
@overload
6418+
def f3(g: B) -> B: ...
6419+
if True:
6420+
@overload
6421+
def f3(g: C) -> C: ...
6422+
def f3(g): ...
6423+
reveal_type(f3(A())) # N: Revealed type is "__main__.A"
6424+
reveal_type(f3(B())) # N: Revealed type is "__main__.B"
6425+
reveal_type(f3(C())) # N: Revealed type is "__main__.C"
6426+
6427+
@overload
6428+
def f4(g: A) -> A: ...
6429+
if True:
6430+
if False:
6431+
@overload
6432+
def f4(g: B) -> B: ...
6433+
else:
6434+
@overload
6435+
def f4(g: C) -> C: ...
6436+
def f4(g): ...
6437+
reveal_type(f4(A())) # N: Revealed type is "__main__.A"
6438+
reveal_type(f4(B())) # E: No overload variant of "f4" matches argument type "B" \
6439+
# N: Possible overload variants: \
6440+
# N: def f4(g: A) -> A \
6441+
# N: def f4(g: C) -> C \
6442+
# N: Revealed type is "Any"
6443+
reveal_type(f4(C())) # N: Revealed type is "__main__.C"
6444+
6445+
@overload
6446+
def f5(g: A) -> A: ...
6447+
if True:
6448+
if False:
6449+
@overload
6450+
def f5(g: B) -> B: ...
6451+
elif True:
6452+
@overload
6453+
def f5(g: C) -> C: ...
6454+
def f5(g): ...
6455+
reveal_type(f5(A())) # N: Revealed type is "__main__.A"
6456+
reveal_type(f5(B())) # E: No overload variant of "f5" matches argument type "B" \
6457+
# N: Possible overload variants: \
6458+
# N: def f5(g: A) -> A \
6459+
# N: def f5(g: C) -> C \
6460+
# N: Revealed type is "Any"
6461+
reveal_type(f5(C())) # N: Revealed type is "__main__.C"
6462+
6463+
[case testOverloadIfNestedFailure]
6464+
# flags: --always-true True --always-false False
6465+
from typing import overload
6466+
6467+
class A: ...
6468+
class B: ...
6469+
class C: ...
6470+
class D: ...
6471+
6472+
@overload # E: Single overload definition, multiple required
6473+
def f1(g: A) -> A: ...
6474+
if True:
6475+
@overload # E: Single overload definition, multiple required
6476+
def f1(g: B) -> B: ...
6477+
if maybe_true: # E: Condition cannot be inferred, unable to merge overloads \
6478+
# E: Name "maybe_true" is not defined
6479+
@overload
6480+
def f1(g: C) -> C: ...
6481+
@overload
6482+
def f1(g: D) -> D: ...
6483+
def f1(g): ... # E: Name "f1" already defined on line 9
6484+
6485+
@overload # E: Single overload definition, multiple required
6486+
def f2(g: A) -> A: ...
6487+
if True:
6488+
if False:
6489+
@overload
6490+
def f2(g: B) -> B: ...
6491+
elif maybe_true: # E: Name "maybe_true" is not defined
6492+
@overload # E: Single overload definition, multiple required
6493+
def f2(g: C) -> C: ...
6494+
def f2(g): ... # E: Name "f2" already defined on line 21
6495+
6496+
@overload # E: Single overload definition, multiple required
6497+
def f3(g: A) -> A: ...
6498+
if True:
6499+
@overload # E: Single overload definition, multiple required
6500+
def f3(g: B) -> B: ...
6501+
if True:
6502+
pass # Some other node
6503+
@overload # E: Name "f3" already defined on line 32 \
6504+
# E: An overloaded function outside a stub file must have an implementation
6505+
def f3(g: C) -> C: ...
6506+
@overload
6507+
def f3(g: D) -> D: ...
6508+
def f3(g): ... # E: Name "f3" already defined on line 32

0 commit comments

Comments
 (0)