Skip to content

Commit 508ffe8

Browse files
committed
Check for any constrained type
1 parent ead9d11 commit 508ffe8

7 files changed

+124
-51
lines changed

mypy/applytype.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
from typing import List, Dict, Sequence
1+
from typing import List, Dict, Sequence, Tuple
22

33
import mypy.subtypes
44
from mypy.sametypes import is_same_type
55
from mypy.expandtype import expand_type
6-
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, Instance
6+
from mypy.types import (
7+
Type, TypeVarId, TypeVarType, TypeVisitor, CallableType, AnyType, PartialType,
8+
Instance, UnionType
9+
)
710
from mypy.messages import MessageBuilder
811
from mypy.nodes import Context
912

@@ -38,10 +41,9 @@ def apply_generic_arguments(callable: CallableType, types: List[Type],
3841
types[i] = value
3942
break
4043
else:
41-
arg_strings = tuple(msg.format(arg_type).replace('"', '')
42-
for arg_type in callable.arg_types)
43-
if has_anystr_incompatible_args(arg_strings, type):
44-
msg.incompatible_anystr_arguments(callable, arg_strings, context)
44+
constraints = get_incompatible_arg_constraints(callable.arg_types, type, i + 1)
45+
if constraints:
46+
msg.incompatible_constrained_arguments(callable, i + 1, constraints, context)
4547
else:
4648
msg.incompatible_typevar_value(callable, i + 1, type, context)
4749
upper_bound = callable.variables[i].upper_bound
@@ -68,14 +70,30 @@ def apply_generic_arguments(callable: CallableType, types: List[Type],
6870
)
6971

7072

71-
def has_anystr_incompatible_args(arg_strings: Sequence[str], type: Type) -> bool:
72-
"""Determines if function has a problem with AnyStr arguments.
73+
def get_incompatible_arg_constraints(arg_types: Sequence[Type], type: Type,
74+
index: int) -> Dict[str, Tuple[str]]:
75+
"""Gets incompatible function arguments with the constrained types.
7376
74-
If the function has more than one AnyStr argument and the solver returns the object type,
75-
then the function was passed both an "str" and "bytes" argument type.
77+
An example of a constrained type is AnyStr which must be all str or all byte.
7678
"""
79+
constraints = {} # type: Dict[str, Tuple[str]]
7780
if isinstance(type, Instance) and type.type.name() == 'object':
78-
for string in arg_strings:
79-
if 'AnyStr' in string:
80-
return True
81-
return False
81+
if index == len(arg_types):
82+
# Index is off by one for '*' arguments
83+
constraints = add_arg_constraints(constraints, arg_types[index - 1])
84+
else:
85+
constraints = add_arg_constraints(constraints, arg_types[index])
86+
return constraints
87+
88+
89+
def add_arg_constraints(constraints: Dict[str, Tuple[str]],
90+
arg_type: Type) -> Dict[str, Tuple[str]]:
91+
if (isinstance(arg_type, TypeVarType) and
92+
arg_type.values and
93+
len(arg_type.values) > 1 and
94+
arg_type.name not in constraints.keys()):
95+
constraints[arg_type.name] = tuple(vals.type.name() for vals in arg_type.values)
96+
elif isinstance(arg_type, UnionType):
97+
for item in arg_type.items:
98+
constraints = add_arg_constraints(constraints, item)
99+
return constraints

mypy/messages.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import difflib
88

9-
from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple
9+
from typing import cast, List, Dict, Any, Sequence, Iterable, Tuple, Mapping
1010

1111
from mypy.erasetype import erase_type
1212
from mypy.errors import Errors
@@ -806,13 +806,20 @@ def incompatible_typevar_value(self, callee: CallableType, index: int,
806806
self.fail('Type argument {} of {} has incompatible value {}'.format(
807807
index, callable_name(callee), self.format(type)), context)
808808

809-
def incompatible_anystr_arguments(self, callee: CallableType, arg_strings: Sequence[str],
810-
context: Context) -> None:
811-
if len(arg_strings) == 1:
812-
arg_strings = str(arg_strings).replace(',)', ')')
813-
call_with_types = '"{}{}"'.format(callable_name(callee).replace('"', ''), arg_strings)
814-
self.fail('Type arguments of {} have incompatible values'.format(call_with_types), context)
815-
self.note('"AnyStr" arguments must be all "str" or all "bytes"', context)
809+
def incompatible_constrained_arguments(self,
810+
callee: CallableType,
811+
index: int,
812+
constraints: Mapping[str, Sequence[str]],
813+
context: Context) -> None:
814+
for key, values in constraints.items():
815+
self.fail('Type argument {} of {} has incompatible value'.format(
816+
index, callable_name(callee)), context)
817+
if len(values) == 2:
818+
constraint_str = '{} or {}'.format(values[0], values[1])
819+
elif len(values) > 3:
820+
constraint_str = ', '.join(values[:-1]) + ', or ' + values[-1]
821+
self.note('"{}" must be all one type: {}'.format(
822+
key, constraint_str), context)
816823

817824
def overloaded_signatures_overlap(self, index1: int, index2: int,
818825
context: Context) -> None:

test-data/unit/check-functions.test

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,34 +2070,79 @@ def fn(
20702070

20712071
[case testAnyStrIncompatibleArguments]
20722072
from typing import TypeVar
2073-
AnyStr = TypeVar('AnyStr', bytes, str)
2073+
AnyStr = TypeVar('AnyStr', str, bytes)
20742074
def f(x: AnyStr, y: AnyStr) -> None: pass
2075-
def g(x: AnyStr, y: AnyStr, z: int) -> None: pass
2075+
def g(x: AnyStr, y: AnyStr, z: int) -> AnyStr: pass
20762076
f('a', 'b')
20772077
f(b'a', b'b')
2078-
f('a', b'b') # E: Type arguments of "f('AnyStr', 'AnyStr')" have incompatible values \
2079-
# N: "AnyStr" arguments must be all "str" or all "bytes"
2078+
f('a', b'b') # E: Type argument 1 of "f" has incompatible value \
2079+
# N: "AnyStr" must be all one type: str or bytes
20802080
g('a', 'b', 1)
2081-
g('a', b'b', 1) # E: Type arguments of "g('AnyStr', 'AnyStr', 'int')" have incompatible values \
2082-
# N: "AnyStr" arguments must be all "str" or all "bytes"
2083-
g('a', b'b', 'c') # E: Type arguments of "g('AnyStr', 'AnyStr', 'int')" have incompatible values \
2084-
# N: "AnyStr" arguments must be all "str" or all "bytes" \
2081+
g(b'a', b'b', 1)
2082+
g('a', b'b', 1) # E: Type argument 1 of "g" has incompatible value \
2083+
# N: "AnyStr" must be all one type: str or bytes
2084+
g('a', b'b', 'c') # E: Type argument 1 of "g" has incompatible value \
2085+
# N: "AnyStr" must be all one type: str or bytes \
20852086
# E: Argument 3 to "g" has incompatible type "str"; expected "int"
20862087

20872088
[case testUnionAnyStrIncompatibleArguments]
20882089
from typing import TypeVar, Union
2089-
AnyStr = TypeVar('AnyStr', bytes, str)
2090+
AnyStr = TypeVar('AnyStr', str, bytes)
20902091
def f(x: Union[AnyStr, int], y: AnyStr) -> None: pass
20912092
f('a', 'b')
20922093
f(1, 'b')
2093-
f('a', b'b') # E: Type arguments of "f('Union[AnyStr, int]', 'AnyStr')" have incompatible values \
2094-
# N: "AnyStr" arguments must be all "str" or all "bytes"
2094+
f('a', b'b') # E: Type argument 1 of "f" has incompatible value \
2095+
# N: "AnyStr" must be all one type: str or bytes
20952096

20962097
[case testStarAnyStrIncompatibleArguments]
20972098
from typing import TypeVar, Union
2098-
AnyStr = TypeVar('AnyStr', bytes, str)
2099+
AnyStr = TypeVar('AnyStr', str, bytes)
20992100
def f(*x: AnyStr) -> None: pass
2101+
def g(x: int, *y: AnyStr) -> None: pass
2102+
def h(*x: AnyStr, y: int) -> None: pass
21002103
f('a')
21012104
f('a', 'b')
2102-
f('a', b'b') # E: Type arguments of "f('AnyStr')" have incompatible values \
2103-
# N: "AnyStr" arguments must be all "str" or all "bytes"
2105+
f('a', b'b') # E: Type argument 1 of "f" has incompatible value \
2106+
# N: "AnyStr" must be all one type: str or bytes
2107+
g(1, 'a')
2108+
g(1, 'a', b'b') # E: Type argument 1 of "g" has incompatible value \
2109+
# N: "AnyStr" must be all one type: str or bytes
2110+
h('a', y=1)
2111+
h('a', 'b', y=1)
2112+
h('a', b'b', y=1) # E: Type argument 1 of "h" has incompatible value "object"
2113+
2114+
[case testConstrainedIncompatibleArguments]
2115+
from typing import TypeVar
2116+
S = TypeVar('S', int, str)
2117+
def f(x: S, y: S) -> S: return (x + y)
2118+
f('1', '2')
2119+
f('1', 2) # E: Type argument 1 of "f" has incompatible value \
2120+
# N: "S" must be all one type: int or str
2121+
f(1, '2') # E: Type argument 1 of "f" has incompatible value \
2122+
# N: "S" must be all one type: int or str
2123+
2124+
[case testMultipleConstrainedIncompatibleArguments]
2125+
from typing import TypeVar
2126+
S = TypeVar('S', int, str)
2127+
AnyStr = TypeVar('AnyStr', str, bytes)
2128+
def f(a: S, b: S, c: AnyStr, d: AnyStr) -> S: return (a + b)
2129+
f('1', '2', '3', '4')
2130+
f('1', '2', b'3', b'4')
2131+
f(1, 2, '3', '4')
2132+
f(1, 2, b'3', b'4')
2133+
f(1, '2', '3', '4') # E: Type argument 1 of "f" has incompatible value \
2134+
# N: "S" must be all one type: int or str
2135+
f('1', 2, '3', '4') # E: Type argument 1 of "f" has incompatible value \
2136+
# N: "S" must be all one type: int or str
2137+
f('1', '2', b'3', '4') # E: Type argument 2 of "f" has incompatible value \
2138+
# N: "AnyStr" must be all one type: str or bytes
2139+
f('1', '2', '3', b'4') # E: Type argument 2 of "f" has incompatible value \
2140+
# N: "AnyStr" must be all one type: str or bytes
2141+
f('1', 2, b'3', '4') # E: Type argument 1 of "f" has incompatible value \
2142+
# N: "S" must be all one type: int or str \
2143+
# E: Type argument 2 of "f" has incompatible value \
2144+
# N: "AnyStr" must be all one type: str or bytes
2145+
f(1, '2', '3', b'4') # E: Type argument 1 of "f" has incompatible value \
2146+
# N: "S" must be all one type: int or str \
2147+
# E: Type argument 2 of "f" has incompatible value \
2148+
# N: "AnyStr" must be all one type: str or bytes

test-data/unit/check-inference.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -750,12 +750,12 @@ AnyStr = TypeVar('AnyStr', bytes, str)
750750
def f(x: Union[AnyStr, int], *a: AnyStr) -> None: pass
751751
f('foo')
752752
f('foo', 'bar')
753-
f('foo', b'bar') # E: Type arguments of "f('Union[AnyStr, int]', 'AnyStr')" have incompatible values \
754-
# N: "AnyStr" arguments must be all "str" or all "bytes"
753+
f('foo', b'bar') # E: Type argument 1 of "f" has incompatible value \
754+
# N: "AnyStr" must be all one type: bytes or str
755755
f(1)
756756
f(1, 'foo')
757-
f(1, 'foo', b'bar') # E: Type arguments of "f('Union[AnyStr, int]', 'AnyStr')" have incompatible values \
758-
# N: "AnyStr" arguments must be all "str" or all "bytes"
757+
f(1, 'foo', b'bar') # E: Type argument 1 of "f" has incompatible value \
758+
# N: "AnyStr" must be all one type: bytes or str
759759
[builtins fixtures/primitives.pyi]
760760

761761

test-data/unit/check-overloading.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -998,12 +998,12 @@ def g(x: int, *a: AnyStr) -> None: pass
998998

999999
g('foo')
10001000
g('foo', 'bar')
1001-
g('foo', b'bar') # E: Type arguments of "g('AnyStr', 'AnyStr')" have incompatible values \
1002-
# N: "AnyStr" arguments must be all "str" or all "bytes"
1001+
g('foo', b'bar') # E: Type argument 1 of "g" has incompatible value \
1002+
# N: "AnyStr" must be all one type: bytes or str
10031003
g(1)
10041004
g(1, 'foo')
1005-
g(1, 'foo', b'bar') # E: Type arguments of "g('int', 'AnyStr')" have incompatible values \
1006-
# N: "AnyStr" arguments must be all "str" or all "bytes"
1005+
g(1, 'foo', b'bar') # E: Type argument 1 of "g" has incompatible value \
1006+
# N: "AnyStr" must be all one type: bytes or str
10071007
[builtins fixtures/primitives.pyi]
10081008

10091009
[case testBadOverlapWithTypeVarsWithValues]

test-data/unit/check-typevar-values.test

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ T = TypeVar('T', int, str)
77
def f(x: T) -> None: pass
88
f(1)
99
f('x')
10-
f(object()) # E: Type argument 1 of "f" has incompatible value "object"
10+
f(object()) # E: Type argument 1 of "f" has incompatible value\
11+
# N: "T" must be all one type: int or str
1112

1213
[case testCallGenericFunctionWithTypeVarValueRestrictionUsingContext]
1314
from typing import TypeVar, List
@@ -18,7 +19,8 @@ s = ['x']
1819
o = [object()]
1920
i = f(1)
2021
s = f('')
21-
o = f(1) # E: Type argument 1 of "f" has incompatible value "object"
22+
o = f(1) # E: Type argument 1 of "f" has incompatible value\
23+
# N: "T" must be all one type: int or str
2224
[builtins fixtures/list.pyi]
2325

2426
[case testCallGenericFunctionWithTypeVarValueRestrictionAndAnyArgs]
@@ -239,7 +241,8 @@ class A(Generic[X]):
239241
A(1)
240242
A('x')
241243
A(cast(Any, object()))
242-
A(object()) # E: Type argument 1 of "A" has incompatible value "object"
244+
A(object()) # E: Type argument 1 of "A" has incompatible value\
245+
# N: "X" must be all one type: int or str
243246

244247
[case testGenericTypeWithTypevarValuesAndTypevarArgument]
245248
from typing import TypeVar, Generic

test-data/unit/pythoneval.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,8 +1251,8 @@ re.subn(bpat, b'', b'')[0] + b''
12511251
re.subn(bre, lambda m: b'', b'')[0] + b''
12521252
re.subn(bpat, lambda m: b'', b'')[0] + b''
12531253
[out]
1254-
_program.py:7: error: Type arguments of "search('AnyStr', 'AnyStr', 'int')" have incompatible values
1255-
_program.py:7: note: "AnyStr" arguments must be all "str" or all "bytes"
1254+
_program.py:7: error: Type argument 1 of "search" has incompatible value
1255+
_program.py:7: note: "AnyStr" must be all one type: str or bytes
12561256
_program.py:9: error: Cannot infer type argument 1 of "search"
12571257

12581258
[case testReModuleString]
@@ -1276,8 +1276,8 @@ re.subn(spat, '', '')[0] + ''
12761276
re.subn(sre, lambda m: '', '')[0] + ''
12771277
re.subn(spat, lambda m: '', '')[0] + ''
12781278
[out]
1279-
_program.py:7: error: Type arguments of "search('AnyStr', 'AnyStr', 'int')" have incompatible values
1280-
_program.py:7: note: "AnyStr" arguments must be all "str" or all "bytes"
1279+
_program.py:7: error: Type argument 1 of "search" has incompatible value
1280+
_program.py:7: note: "AnyStr" must be all one type: str or bytes
12811281
_program.py:9: error: Cannot infer type argument 1 of "search"
12821282

12831283
[case testListSetitemTuple]

0 commit comments

Comments
 (0)