Skip to content

Commit b644f6c

Browse files
committed
Merge branch 'fix-overload-resolution'
2 parents 0ab7e9f + 38b47b0 commit b644f6c

File tree

7 files changed

+91
-57
lines changed

7 files changed

+91
-57
lines changed

mypy/checkexpr.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -605,16 +605,14 @@ def matches_signature_erased(self, arg_types: List[Type], is_var_arg: bool,
605605
# Use is_more_precise rather than is_subtype because it ignores ducktype
606606
# declarations. This is important since ducktype declarations are ignored
607607
# when performing runtime type checking.
608-
if not is_more_precise(self.erase(arg_types[i]),
609-
self.erase(
610-
callee.arg_types[i])):
608+
if not is_compatible_overload_arg(arg_types[i], callee.arg_types[i]):
611609
return False
612610
# Function varargs.
613611
if callee.is_var_arg:
614612
for i in range(func_fixed, len(arg_types)):
615613
# See above for why we use is_more_precise.
616-
if not is_more_precise(self.erase(arg_types[i]),
617-
self.erase(callee.arg_types[func_fixed])):
614+
if not is_compatible_overload_arg(arg_types[i],
615+
callee.arg_types[func_fixed]):
618616
return False
619617
return True
620618

@@ -1250,10 +1248,6 @@ def unwrap_list(self, a: List[Node]) -> List[Node]:
12501248
r.append(self.strip_parens(n))
12511249
return r
12521250

1253-
def erase(self, type: Type) -> Type:
1254-
"""Replace type variable types in type with Any."""
1255-
return erasetype.erase_type(type)
1256-
12571251

12581252
def is_valid_argc(nargs: int, is_var_arg: bool, callable: Callable) -> bool:
12591253
"""Return a boolean indicating whether a call expression has a
@@ -1401,3 +1395,30 @@ def __init__(self) -> None:
14011395

14021396
def visit_erased_type(self, t: ErasedType) -> bool:
14031397
return True
1398+
1399+
1400+
def is_compatible_overload_arg(actual: Type, formal: Type) -> bool:
1401+
if (isinstance(actual, NoneTyp) or isinstance(actual, AnyType) or
1402+
isinstance(formal, AnyType) or isinstance(formal, TypeVar) or
1403+
isinstance(formal, Callable)):
1404+
# These could match anything at runtime.
1405+
return True
1406+
if isinstance(actual, UnionType):
1407+
return any(is_compatible_overload_arg(item, formal)
1408+
for item in actual.items)
1409+
if isinstance(formal, UnionType):
1410+
return any(is_compatible_overload_arg(actual, item)
1411+
for item in formal.items)
1412+
if isinstance(formal, Instance):
1413+
if isinstance(actual, Callable):
1414+
actual = actual.fallback
1415+
if isinstance(actual, Overloaded):
1416+
actual = actual.items()[0].fallback
1417+
if isinstance(actual, TupleType):
1418+
actual = actual.fallback
1419+
if isinstance(actual, Instance):
1420+
return formal.type in actual.type.mro
1421+
else:
1422+
return False
1423+
# Fall back to a conservative equality check for the remaining kinds of type.
1424+
return is_same_type(erasetype.erase_type(actual), erasetype.erase_type(formal))

mypy/meet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ def is_overlapping_types(t: Type, s: Type) -> bool:
7171
if tbuiltin in sbuiltin.mro or sbuiltin in tbuiltin.mro:
7272
return True
7373
return tbuiltin == sbuiltin
74-
# We conservatively assume that non-instance types can overlap any other
74+
if isinstance(t, UnionType):
75+
return any(is_overlapping_types(item, s)
76+
for item in t.items)
77+
if isinstance(s, UnionType):
78+
return any(is_overlapping_types(t, item)
79+
for item in s.items)
80+
# We conservatively assume that non-instance, non-union types can overlap any other
7581
# types.
7682
return True
7783

mypy/test/data/check-unions.test

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,21 @@ i = y.foo() # E: Incompatible types in assignment (expression has type "Union[
7979

8080
[case testUnionIndexing]
8181
from typing import Union, List, Undefined
82-
8382
x = Undefined # type: Union[List[int], str]
8483
x[2]
8584
x[2] + 1 # E: Unsupported operand types for + (likely involving Union)
86-
87-
8885
[builtins fixtures/isinstancelist.py]
86+
87+
[case testUnionAsOverloadArg]
88+
from typing import Union, overload
89+
@overload
90+
def f(x: Union[int, str]) -> int: pass
91+
@overload
92+
def f(x: type) -> str: pass
93+
x = 0
94+
x = f(1)
95+
x = f('')
96+
s = ''
97+
s = f(int)
98+
s = f(1) # E: Incompatible types in assignment (expression has type "int", variable has type "str")
99+
x = f(int) # E: Incompatible types in assignment (expression has type "str", variable has type "int")

mypy/test/data/python2eval.test

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,14 @@ print f('ab')
247247
[out]
248248
12
249249
ab
250+
251+
[case testStrAdd_python2]
252+
import typing
253+
s = ''
254+
u = u''
255+
n = 0
256+
n = s + '' # E
257+
s = s + u'' # E
258+
[out]
259+
_program.py, line 5: Incompatible types in assignment (expression has type "str", variable has type "int")
260+
_program.py, line 6: Incompatible types in assignment (expression has type "unicode", variable has type "str")

mypy/test/data/pythoneval.test

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,3 +919,11 @@ map(f, ['x'])
919919
map(f, [1])
920920
[out]
921921
_program.py, line 4: Argument 1 to "map" has incompatible type Function[["str"] -> "str"]; expected Function[["int"] -> "str"]
922+
923+
[case testMapStr]
924+
import typing
925+
x = range(3)
926+
a = list(map(str, x))
927+
a + 1
928+
[out]
929+
_program.py, line 4: Unsupported operand types for + (List[str] and "int")

stubs/2.7/builtins.py

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Undefined, typevar, AbstractGeneric, Iterator, Iterable, overload,
55
Sequence, Mapping, Tuple, List, Any, Dict, Function, Generic, Set,
66
AbstractSet, Sized, Reversible, SupportsInt, SupportsFloat, SupportsAbs,
7-
SupportsRound, IO, BinaryIO, builtinclass, ducktype
7+
SupportsRound, IO, BinaryIO, builtinclass, ducktype, Union
88
)
99
from abc import abstractmethod, ABCMeta
1010

@@ -152,8 +152,12 @@ def __init__(self) -> None: pass
152152
@overload
153153
def __init__(self, o: object) -> None: pass
154154
@overload
155+
def __init__(self, o: str, encoding: str = None, errors: str = 'strict') -> None: pass
156+
@overload
155157
def __init__(self, o: str, encoding: unicode = None, errors: unicode = 'strict') -> None: pass
156158
@overload
159+
def __init__(self, o: bytearray, encoding: str = None, errors: str = 'strict') -> None: pass
160+
@overload
157161
def __init__(self, o: bytearray, encoding: unicode = None,
158162
errors: unicode = 'strict') -> None: pass
159163
def capitalize(self) -> unicode: pass
@@ -232,37 +236,14 @@ def __hash__(self) -> int: pass
232236
class str(Sequence[str]):
233237
def __init__(self, object: object) -> None: pass
234238
def capitalize(self) -> str: pass
235-
@overload
236-
def center(self, width: int, fillchar: str = None) -> str: pass
237-
@overload
238-
def center(self, width: int, fillchar: bytearray = None) -> str: pass
239-
@overload
240-
def count(self, x: unicode) -> int: pass
241-
@overload
242-
def count(self, x: bytearray) -> int: pass
239+
def center(self, width: int, fillchar: Union[str, bytearray] = None) -> str: pass
240+
def count(self, x: Union[unicode, bytearray]) -> int: pass
243241
def decode(self, encoding: unicode = 'utf-8', errors: unicode = 'strict') -> unicode: pass
244242
def encode(self, encoding: unicode = 'utf-8', errors: unicode = 'strict') -> str: pass
245-
@overload
246-
def endswith(self, suffix: unicode) -> bool: pass
247-
@overload
248-
def endswith(self, suffix: bytearray) -> bool: pass
243+
def endswith(self, suffix: Union[unicode, bytearray]) -> bool: pass
249244
def expandtabs(self, tabsize: int = 8) -> str: pass
250-
@overload
251-
def find(self, sub: unicode, start: int = 0) -> int: pass
252-
@overload
253-
def find(self, sub: unicode, start: int, end: int) -> int: pass
254-
@overload
255-
def find(self, sub: bytearray, start: int = 0) -> int: pass
256-
@overload
257-
def find(self, sub: bytearray, start: int, end: int) -> int: pass
258-
@overload
259-
def index(self, sub: unicode, start: int = 0) -> int: pass
260-
@overload
261-
def index(self, sub: unicode, start: int, end: int) -> int: pass
262-
@overload
263-
def index(self, sub: bytearray, start: int = 0) -> int: pass
264-
@overload
265-
def index(self, sub: bytearray, start: int, end: int) -> int: pass
245+
def find(self, sub: Union[unicode, bytearray], start: int = 0, end: int = 0) -> int: pass
246+
def index(self, sub: Union[unicode, bytearray], start: int = 0, end: int = 0) -> int: pass
266247
def isalnum(self) -> bool: pass
267248
def isalpha(self) -> bool: pass
268249
def isdigit(self) -> bool: pass
@@ -274,10 +255,7 @@ def isupper(self) -> bool: pass
274255
def join(self, iterable: Iterable[str]) -> str: pass # TODO unicode
275256
@overload
276257
def join(self, iterable: Iterable[bytearray]) -> str: pass
277-
@overload
278-
def ljust(self, width: int, fillchar: str = None) -> str: pass
279-
@overload
280-
def ljust(self, width: int, fillchar: bytearray = None) -> str: pass
258+
def ljust(self, width: int, fillchar: Union[str, bytearray] = None) -> str: pass
281259
def lower(self) -> str: pass
282260
@overload
283261
def lstrip(self, chars: str = None) -> str: pass # TODO unicode
@@ -307,10 +285,7 @@ def rindex(self, sub: unicode, start: int, end: int) -> int: pass
307285
def rindex(self, sub: bytearray, start: int = 0) -> int: pass
308286
@overload
309287
def rindex(self, sub: bytearray, start: int, end: int) -> int: pass
310-
@overload
311-
def rjust(self, width: int, fillchar: str = None) -> str: pass
312-
@overload
313-
def rjust(self, width: int, fillchar: bytearray = None) -> str: pass
288+
def rjust(self, width: int, fillchar: Union[str, bytearray] = None) -> str: pass
314289
@overload
315290
def rpartition(self, sep: str) -> Tuple[str, str, str]: pass # TODO unicode
316291
@overload
@@ -331,10 +306,7 @@ def split(self, sep: str = None, maxsplit: int = -1) -> List[str]: pass # TODO
331306
def split(self, sep: bytearray = None, # TODO unicode
332307
maxsplit: int = -1) -> List[str]: pass
333308
def splitlines(self, keepends: bool = False) -> List[str]: pass
334-
@overload
335-
def startswith(self, prefix: unicode) -> bool: pass
336-
@overload
337-
def startswith(self, prefix: bytearray) -> bool: pass
309+
def startswith(self, prefix: Union[unicode, bytearray]) -> bool: pass
338310
@overload
339311
def strip(self, chars: str = None) -> str: pass # TODO unicode
340312
@overload
@@ -363,9 +335,11 @@ def __getitem__(self, i: int) -> str: pass
363335
def __getitem__(self, s: slice) -> str: pass
364336
def __getslice__(self, start: int, stop: int) -> str: pass
365337
@overload
366-
def __add__(self, s: str) -> str: pass # TODO unicode
338+
def __add__(self, s: str) -> str: pass
367339
@overload
368340
def __add__(self, s: bytearray) -> str: pass
341+
@overload
342+
def __add__(self, s: unicode) -> unicode: pass
369343
def __mul__(self, n: int) -> str: pass
370344
def __rmul__(self, n: int) -> str: pass
371345
def __contains__(self, o: object) -> bool: pass
@@ -799,9 +773,11 @@ def next(i: Iterator[_T]) -> _T: pass
799773
def next(i: Iterator[_T], default: _T) -> _T: pass
800774
def oct(i: int) -> str: pass # TODO __index__
801775
@overload
802-
def open(file: unicode, mode: unicode = 'r', buffering: int = -1) -> BinaryIO: pass
776+
def open(file: str, mode: str = 'r', buffering: int = -1) -> BinaryIO: pass
777+
@overload
778+
def open(file: unicode, mode: str = 'r', buffering: int = -1) -> BinaryIO: pass
803779
@overload
804-
def open(file: int, mode: unicode = 'r', buffering: int = -1) -> BinaryIO: pass
780+
def open(file: int, mode: str = 'r', buffering: int = -1) -> BinaryIO: pass
805781
@overload
806782
def ord(c: unicode) -> int: pass
807783
@overload

travis.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ echo tests.py
2424
for t in mypy.test.testpythoneval; do
2525
echo $t
2626
"$PYTHON" "$DRIVER" -m $t || fail
27+
"$PYTHON" -m $t || fail
2728
done
2829

2930
# Stub checks

0 commit comments

Comments
 (0)