Skip to content

Commit 8d88223

Browse files
committed
Support overloading with TypedDict
Fix #3609.
1 parent 5d85832 commit 8d88223

File tree

5 files changed

+132
-3
lines changed

5 files changed

+132
-3
lines changed

mypy/checkexpr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2674,6 +2674,12 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
26742674
return overload_arg_similarity(actual.ret_type, formal.item)
26752675
else:
26762676
return 0
2677+
if isinstance(actual, TypedDictType):
2678+
if isinstance(formal, TypedDictType):
2679+
# Don't support overloading based on the keys or value types of a TypedDict since
2680+
# that would be complicated and probably only marginally useful.
2681+
return 2
2682+
return overload_arg_similarity(actual.fallback, formal)
26772683
if isinstance(formal, Instance):
26782684
if isinstance(actual, CallableType):
26792685
actual = actual.fallback

mypy/meet.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ class C(A, B): ...
8484
t = t.erase_to_union_or_bound()
8585
if isinstance(s, TypeVarType):
8686
s = s.erase_to_union_or_bound()
87+
if isinstance(t, TypedDictType):
88+
t = t.as_anonymous().fallback
89+
if isinstance(s, TypedDictType):
90+
s = s.as_anonymous().fallback
8791
if isinstance(t, Instance):
8892
if isinstance(s, Instance):
8993
# Consider two classes non-disjoint if one is included in the mro

test-data/unit/check-typeddict.test

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,125 @@ Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type
10321032
[builtins fixtures/dict.pyi]
10331033

10341034

1035+
-- Overloading
1036+
1037+
[case testTypedDictOverloading]
1038+
from typing import overload, Iterable
1039+
from mypy_extensions import TypedDict
1040+
1041+
A = TypedDict('A', {'x': int})
1042+
1043+
@overload
1044+
def f(x: Iterable[str]) -> str: ...
1045+
@overload
1046+
def f(x: int) -> int: ...
1047+
def f(x): pass
1048+
1049+
a: A
1050+
reveal_type(f(a)) # E: Revealed type is 'builtins.str'
1051+
reveal_type(f(1)) # E: Revealed type is 'builtins.int'
1052+
[builtins fixtures/dict.pyi]
1053+
[typing fixtures/typing-full.pyi]
1054+
1055+
[case testTypedDictOverloading2]
1056+
from typing import overload, Iterable
1057+
from mypy_extensions import TypedDict
1058+
1059+
A = TypedDict('A', {'x': int})
1060+
1061+
@overload
1062+
def f(x: Iterable[int]) -> None: ...
1063+
@overload
1064+
def f(x: int) -> None: ...
1065+
def f(x): pass
1066+
1067+
a: A
1068+
f(a) # E: Argument 1 to "f" has incompatible type "A"; expected Iterable[int]
1069+
[builtins fixtures/dict.pyi]
1070+
[typing fixtures/typing-full.pyi]
1071+
1072+
[case testTypedDictOverloading3]
1073+
from typing import overload
1074+
from mypy_extensions import TypedDict
1075+
1076+
A = TypedDict('A', {'x': int})
1077+
1078+
@overload
1079+
def f(x: str) -> None: ...
1080+
@overload
1081+
def f(x: int) -> None: ...
1082+
def f(x): pass
1083+
1084+
a: A
1085+
f(a) # E: No overload variant of "f" matches argument types [TypedDict(x=builtins.int, _fallback=__main__.A)]
1086+
[builtins fixtures/dict.pyi]
1087+
[typing fixtures/typing-full.pyi]
1088+
1089+
[case testTypedDictOverloading4]
1090+
from typing import overload
1091+
from mypy_extensions import TypedDict
1092+
1093+
A = TypedDict('A', {'x': int})
1094+
B = TypedDict('B', {'x': str})
1095+
1096+
@overload
1097+
def f(x: A) -> int: ...
1098+
@overload
1099+
def f(x: int) -> str: ...
1100+
def f(x): pass
1101+
1102+
a: A
1103+
b: B
1104+
reveal_type(f(a)) # E: Revealed type is 'builtins.int'
1105+
reveal_type(f(1)) # E: Revealed type is 'builtins.str'
1106+
f(b) # E: Argument 1 to "f" has incompatible type "B"; expected "A"
1107+
[builtins fixtures/dict.pyi]
1108+
[typing fixtures/typing-full.pyi]
1109+
1110+
[case testTypedDictOverloading5]
1111+
from typing import overload
1112+
from mypy_extensions import TypedDict
1113+
1114+
A = TypedDict('A', {'x': int})
1115+
B = TypedDict('B', {'y': str})
1116+
C = TypedDict('C', {'y': int})
1117+
1118+
@overload
1119+
def f(x: A) -> None: ...
1120+
@overload
1121+
def f(x: B) -> None: ...
1122+
def f(x): pass
1123+
1124+
a: A
1125+
b: B
1126+
c: C
1127+
f(a)
1128+
f(b)
1129+
f(c) # E: Argument 1 to "f" has incompatible type "C"; expected "A"
1130+
[builtins fixtures/dict.pyi]
1131+
[typing fixtures/typing-full.pyi]
1132+
1133+
[case testTypedDictOverloading6]
1134+
from typing import overload
1135+
from mypy_extensions import TypedDict
1136+
1137+
A = TypedDict('A', {'x': int})
1138+
B = TypedDict('B', {'y': str})
1139+
1140+
@overload
1141+
def f(x: A) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
1142+
@overload
1143+
def f(x: B) -> str: ...
1144+
def f(x): pass
1145+
1146+
a: A
1147+
b: B
1148+
reveal_type(f(a)) # E: Revealed type is 'Any'
1149+
reveal_type(f(b)) # E: Revealed type is 'Any'
1150+
[builtins fixtures/dict.pyi]
1151+
[typing fixtures/typing-full.pyi]
1152+
1153+
10351154
-- Special cases
10361155

10371156
[case testForwardReferenceInTypedDict]

test-data/unit/fixtures/dict.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class object:
1111

1212
class type: pass
1313

14-
class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]):
14+
class dict(Mapping[KT, VT], Iterable[KT], Generic[KT, VT]):
1515
@overload
1616
def __init__(self, **kwargs: VT) -> None: pass
1717
@overload

test-data/unit/fixtures/typing-full.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ class Sequence(Iterable[T], Generic[T]):
103103
@abstractmethod
104104
def __getitem__(self, n: Any) -> T: pass
105105

106-
class Mapping(Generic[T, U]):
106+
class Mapping(Iterable[T], Generic[T, U]):
107107
@overload
108108
def get(self, k: T) -> Optional[U]: ...
109109
@overload
110110
def get(self, k: T, default: Union[U, V]) -> Union[U, V]: ...
111111

112-
class MutableMapping(Generic[T, U]): pass
112+
class MutableMapping(Mapping[T, U]): pass
113113

114114
TYPE_CHECKING = 1

0 commit comments

Comments
 (0)