From 3df1b06d73ff4a9e7a813f6bb0509e4c99ce0317 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 30 Mar 2018 01:26:07 -0700 Subject: [PATCH 1/3] Adds support for basic union math with overloads This commit adds support for very basic and simple union math when calling overloaded functions, resolving #4576. As a side effect, this change also fixes a bug where calling overloaded functions can sometimes silently infer a return type of 'Any' and slightly modifies the semantics of how mypy handles overlaps in overloaded functions. Details on specific changes made: 1. The new algorithm works by modifying checkexpr.overload_call_targets to return all possible matches, rather then just one. We start by trying the first matching signature. If there was some error, we (conservatively) attempt to union all of the matching signatures together and repeat the typechecking process. If it doesn't seem like it's possible to combine the matching signatures in a sound way, we end and just output the errors we obtained from typechecking the first match. The "signature-unioning" code is currently deliberately very conservative. I figured it was better to start small and attempt to handle only basic cases like #1943 and relax the restrictions later as needed. For more details on this algorithm, see the comments in checkexpr.union_overload_matches. 2. This change incidentally resolves any bugs related to how calling an overloaded function can sometimes silently infer a return type of Any. Previously, if a function call caused an overload to be less precise then a previous one, we gave up and returned a silent Any. This change removes this case altogether and only infers Any if either (a) the caller arguments explicitly contains Any or (b) if there was some error. For example, see #3295 and #1322 -- I believe this pull request touches on and maybe resolves (??) those two issues. 3. As a result, this caused a few errors in mypy where code was relying on this "silently infer Any" behavior -- see the changes in checker.py and semanal.py. Both files were using expressions of the form `zip(*iterable)`, which ended up having a type of `Any` under the old algorithm. The new algorithm will instead infer `Iterable[Tuple[Any, ...]]` which actually matches the stubs in typeshed. 4. Many of the attrs tests were also relying on the same behavior. Specifically, these changes cause the attr stubs in `test-data/unit/lib-stub` to no longer work. It seemed that expressions of the form `a = attr.ib()` were evaluated to 'Any' not because of a stub, but because of the 'silent Any' bug. I couldn't find a clean way of fixing the stubs to infer the correct thing under this new behavior, so just gave up and removed the overloads altogether. I think this is fine though -- it seems like the attrs plugin infers the correct type for us anyways, regardless of what the stubs say. If this pull request is accepted, I plan on submitting a similar pull request to the stubs in typeshed. 4. This pull request also probably touches on https://github.com/python/typing/issues/253. We still require the overloads to be written from the most narrow to general and disallow overlapping signatures. However, if a *call* now causes overlaps, we try the "union" algorithm described above and default to selecting the first matching overload instead of giving up. --- mypy/checker.py | 4 +- mypy/checkexpr.py | 209 +++++++++++++++++++------- mypy/messages.py | 14 +- mypy/semanal.py | 2 +- mypy/types.py | 11 +- test-data/unit/check-overloading.test | 96 +++++++++++- test-data/unit/check-protocols.test | 3 +- test-data/unit/check-typeddict.test | 6 +- test-data/unit/lib-stub/attr.pyi | 18 +-- 9 files changed, 278 insertions(+), 85 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 99ef6c8e6bd6..8527f544cd60 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1801,8 +1801,8 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E expr = expr.expr types, declared_types = zip(*items) self.binder.assign_type(expr, - UnionType.make_simplified_union(types), - UnionType.make_simplified_union(declared_types), + UnionType.make_simplified_union(list(types)), + UnionType.make_simplified_union(list(declared_types)), False) for union, lv in zip(union_types, self.flatten_lvalues(lvalues)): # Properly store the inferred types. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 1066fcc758ae..35497071a88b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -611,10 +611,63 @@ def check_call(self, callee: Type, args: List[Expression], arg_types = self.infer_arg_types_in_context(None, args) self.msg.enable_errors() - target = self.overload_call_target(arg_types, arg_kinds, arg_names, - callee, context, - messages=arg_messages) - return self.check_call(target, args, arg_kinds, context, arg_names, + overload_messages = arg_messages.copy() + targets = self.overload_call_targets(arg_types, arg_kinds, arg_names, + callee, context, + messages=overload_messages) + + # If there are multiple targets, that means that there were + # either multiple possible matches or the types were overlapping in some + # way. In either case, we default to picking the first match and + # see what happens if we try using it. + # + # Note: if we pass in an argument that inherits from two overloaded + # types, we default to picking the first match. For example: + # + # class A: pass + # class B: pass + # class C(A, B): pass + # + # @overload + # def f(x: A) -> int: ... + # @overload + # def f(x: B) -> str: ... + # def f(x): ... + # + # reveal_type(f(C())) # Will be 'int', not 'Union[int, str]' + # + # It's unclear if this is really the best thing to do, but multiple + # inheritance is rare. See the docstring of mypy.meet.is_overlapping_types + # for more about this. + + original_output = self.check_call(targets[0], args, arg_kinds, context, arg_names, + arg_messages=overload_messages, + callable_name=callable_name, + object_type=object_type) + + if not overload_messages.is_errors() or len(targets) == 1: + # If there were no errors or if there was only one match, we can end now. + # + # Note that if we have only one target, there's nothing else we + # can try doing. In that case, we just give up and return early + # and skip the below steps. + arg_messages.add_errors(overload_messages) + return original_output + + # Otherwise, we attempt to synthesize together a new callable by combining + # together the different matches by union-ing together their arguments + # and return type. + + targets = cast(List[CallableType], targets) + unioned_callable = self.union_overload_matches(targets) + if unioned_callable is None: + # If it was not possible to actually combine together the + # callables in a sound way, we give up and return the original + # error message. + arg_messages.add_errors(overload_messages) + return original_output + + return self.check_call(unioned_callable, args, arg_kinds, context, arg_names, arg_messages=arg_messages, callable_name=callable_name, object_type=object_type) @@ -1089,7 +1142,7 @@ def check_arg(self, caller_type: Type, original_caller_type: Type, (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and # ...except for classmethod first argument not caller_type.is_classmethod_class): - self.msg.concrete_only_call(callee_type, context) + messages.concrete_only_call(callee_type, context) elif not is_subtype(caller_type, callee_type): if self.chk.should_suppress_optional_error([caller_type, callee_type]): return @@ -1097,75 +1150,115 @@ def check_arg(self, caller_type: Type, original_caller_type: Type, caller_kind, context) if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and isinstance(callee_type, Instance) and callee_type.type.is_protocol): - self.msg.report_protocol_problems(original_caller_type, callee_type, context) + messages.report_protocol_problems(original_caller_type, callee_type, context) if (isinstance(callee_type, CallableType) and isinstance(original_caller_type, Instance)): call = find_member('__call__', original_caller_type, original_caller_type) if call: - self.msg.note_call(original_caller_type, call, context) - - def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int], - arg_names: Optional[Sequence[Optional[str]]], - overload: Overloaded, context: Context, - messages: Optional[MessageBuilder] = None) -> Type: - """Infer the correct overload item to call with given argument types. - - The return value may be CallableType or AnyType (if an unique item - could not be determined). + messages.note_call(original_caller_type, call, context) + + def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int], + arg_names: Optional[Sequence[Optional[str]]], + overload: Overloaded, context: Context, + messages: Optional[MessageBuilder] = None) -> Sequence[Type]: + """Infer all possible overload targets to call with given argument types. + The list is guaranteed be one of the following: + + 1. A List[CallableType] of length 1 if we were able to find an + unambiguous best match. + 2. A List[AnyType] of length 1 if we were unable to find any match + or discovered the match was ambiguous due to conflicting Any types. + 3. A List[CallableType] of length 2 or more if there were multiple + plausible matches. The matches are returned in the order they + were defined. """ messages = messages or self.msg - # TODO: For overlapping signatures we should try to get a more precise - # result than 'Any'. match = [] # type: List[CallableType] best_match = 0 for typ in overload.items(): similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names, typ, context=context) if similarity > 0 and similarity >= best_match: - if (match and not is_same_type(match[-1].ret_type, - typ.ret_type) and - (not mypy.checker.is_more_precise_signature(match[-1], typ) - or (any(isinstance(arg, AnyType) for arg in arg_types) - and any_arg_causes_overload_ambiguity( - match + [typ], arg_types, arg_kinds, arg_names)))): - # Ambiguous return type. Either the function overload is - # overlapping (which we don't handle very well here) or the - # caller has provided some Any argument types; in either - # case we'll fall back to Any. It's okay to use Any types - # in calls. - # - # Overlapping overload items are generally fine if the - # overlapping is only possible when there is multiple - # inheritance, as this is rare. See docstring of - # mypy.meet.is_overlapping_types for more about this. - # - # Note that there is no ambiguity if the items are - # covariant in both argument types and return types with - # respect to type precision. We'll pick the best/closest - # match. - # - # TODO: Consider returning a union type instead if the - # overlapping is NOT due to Any types? - return AnyType(TypeOfAny.special_form) - else: - match.append(typ) + if (match and not is_same_type(match[-1].ret_type, typ.ret_type) + and any(isinstance(arg, AnyType) for arg in arg_types) + and any_arg_causes_overload_ambiguity( + match + [typ], arg_types, arg_kinds, arg_names)): + # Ambiguous return type. The caller has provided some + # Any argument types (which are okay to use in calls), + # so we fall back to returning 'Any'. + return [AnyType(TypeOfAny.special_form)] + match.append(typ) best_match = max(best_match, similarity) - if not match: + + if len(match) == 0: if not self.chk.should_suppress_optional_error(arg_types): messages.no_variant_matches_arguments(overload, arg_types, context) - return AnyType(TypeOfAny.from_error) + return [AnyType(TypeOfAny.from_error)] + elif len(match) == 1: + return match else: - if len(match) == 1: - return match[0] - else: - # More than one signature matches. Pick the first *non-erased* - # matching signature, or default to the first one if none - # match. - for m in match: - if self.match_signature_types(arg_types, arg_kinds, arg_names, m, - context=context): - return m - return match[0] + # More than one signature matches or the signatures are + # overlapping. In either case, we return all of the matching + # signatures and let the caller decide what to do with them. + out = [m for m in match if self.match_signature_types( + arg_types, arg_kinds, arg_names, m, context=context)] + return out if len(out) >= 1 else match + + def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]: + """Accepts a list of overload signatures and attempts to combine them together into a + new CallableType consisting of the union of all of the given arguments and return types. + + Returns None if it is not possible to combine the different callables together in a + sound manner.""" + + new_args: List[List[Type]] = [[] for _ in range(len(callables[0].arg_types))] + + expected_names = callables[0].arg_names + expected_kinds = callables[0].arg_kinds + + for target in callables: + if target.arg_names != expected_names or target.arg_kinds != expected_kinds: + # We conservatively end if the overloads do not have the exact same signature. + # TODO: Enhance the union overload logic to handle a wider variety of signatures. + return None + + for i, arg in enumerate(target.arg_types): + new_args[i].append(arg) + + union_count = 0 + final_args = [] + for args in new_args: + new_type = UnionType.make_simplified_union(args) + union_count += 1 if isinstance(new_type, UnionType) else 0 + final_args.append(new_type) + + # TODO: Modify this check to be less conservative. + # + # Currently, we permit only one union union in the arguments because if we allow + # multiple, we can't always guarantee the synthesized callable will be correct. + # + # For example, suppose we had the following two overloads: + # + # @overload + # def f(x: A, y: B) -> None: ... + # @overload + # def f(x: B, y: A) -> None: ... + # + # If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...", + # then we'd incorrectly accept calls like "f(A(), A())" when they really ought to + # be rejected. + # + # However, that means we'll also give up if the original overloads contained + # any unions. This is likely unnecessary -- we only really need to give up if + # there are more then one *synthesized* union arguments. + if union_count >= 2: + return None + + return callables[0].copy_modified( + arg_types=final_args, + ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]), + implicit=True, + from_overloads=True) def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int], arg_names: Optional[Sequence[Optional[str]]], diff --git a/mypy/messages.py b/mypy/messages.py index c58b2d84bdb2..8184610883e8 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -629,8 +629,19 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type: expected_type = callee.arg_types[m - 1] except IndexError: # Varargs callees expected_type = callee.arg_types[-1] + arg_type_str, expected_type_str = self.format_distinctly( arg_type, expected_type, bare=True) + expected_type_str = self.quote_type_string(expected_type_str) + + if callee.from_overloads and isinstance(expected_type, UnionType): + expected_formatted = [] + for e in expected_type.items: + type_str = self.format_distinctly(arg_type, e, bare=True)[1] + expected_formatted.append(self.quote_type_string(type_str)) + expected_type_str = 'one of {} based on available overloads'.format( + ', '.join(expected_formatted)) + if arg_kind == ARG_STAR: arg_type_str = '*' + arg_type_str elif arg_kind == ARG_STAR2: @@ -645,8 +656,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type: arg_label = '"{}"'.format(arg_name) msg = 'Argument {} {}has incompatible type {}; expected {}'.format( - arg_label, target, self.quote_type_string(arg_type_str), - self.quote_type_string(expected_type_str)) + arg_label, target, self.quote_type_string(arg_type_str), expected_type_str) if isinstance(arg_type, Instance) and isinstance(expected_type, Instance): notes = append_invariance_notes(notes, arg_type, expected_type) self.fail(msg, context) diff --git a/mypy/semanal.py b/mypy/semanal.py index a8469ae53ba3..988f4eb3e663 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2204,7 +2204,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, # about the length mismatch in type-checking. elementwise_assignments = zip(rval.items, *[v.items for v in seq_lvals]) for rv, *lvs in elementwise_assignments: - self.process_module_assignment(lvs, rv, ctx) + self.process_module_assignment(list(lvs), rv, ctx) elif isinstance(rval, RefExpr): rnode = self.lookup_type_node(rval) if rnode and rnode.kind == MODULE_REF: diff --git a/mypy/types.py b/mypy/types.py index 85c0e5aae906..afcc703bcf85 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -661,6 +661,8 @@ class CallableType(FunctionLike): special_sig = None # type: Optional[str] # Was this callable generated by analyzing Type[...] instantiation? from_type_type = False # type: bool + # Was this callable generated by synthesizing multiple overloads? + from_overloads = False # type: bool bound_args = None # type: List[Optional[Type]] @@ -680,6 +682,7 @@ def __init__(self, is_classmethod_class: bool = False, special_sig: Optional[str] = None, from_type_type: bool = False, + from_overloads: bool = False, bound_args: Optional[List[Optional[Type]]] = None, ) -> None: assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -704,6 +707,7 @@ def __init__(self, self.is_classmethod_class = is_classmethod_class self.special_sig = special_sig self.from_type_type = from_type_type + self.from_overloads = from_overloads self.bound_args = bound_args or [] super().__init__(line, column) @@ -719,8 +723,10 @@ def copy_modified(self, line: int = _dummy, column: int = _dummy, is_ellipsis_args: bool = _dummy, + implicit: bool = _dummy, special_sig: Optional[str] = _dummy, from_type_type: bool = _dummy, + from_overloads: bool = _dummy, bound_args: List[Optional[Type]] = _dummy) -> 'CallableType': return CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -735,10 +741,11 @@ def copy_modified(self, column=column if column is not _dummy else self.column, is_ellipsis_args=( is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args), - implicit=self.implicit, + implicit=implicit if implicit is not _dummy else self.implicit, is_classmethod_class=self.is_classmethod_class, special_sig=special_sig if special_sig is not _dummy else self.special_sig, from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type, + from_overloads=from_overloads if from_overloads is not _dummy else self.from_overloads, bound_args=bound_args if bound_args is not _dummy else self.bound_args, ) @@ -890,6 +897,7 @@ def serialize(self) -> JsonDict: 'is_ellipsis_args': self.is_ellipsis_args, 'implicit': self.implicit, 'is_classmethod_class': self.is_classmethod_class, + 'from_overloads': self.from_overloads, 'bound_args': [(None if t is None else t.serialize()) for t in self.bound_args], } @@ -908,6 +916,7 @@ def deserialize(cls, data: JsonDict) -> 'CallableType': is_ellipsis_args=data['is_ellipsis_args'], implicit=data['implicit'], is_classmethod_class=data['is_classmethod_class'], + from_overloads=data['from_overloads'], bound_args=[(None if t is None else deserialize_type(t)) for t in data['bound_args']], ) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 88eabcaf4a61..b4d1e07347d2 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -632,7 +632,7 @@ n = 1 m = 1 n = 'x' # E: Incompatible types in assignment (expression has type "str", variable has type "int") m = 'x' # E: Incompatible types in assignment (expression has type "str", variable has type "int") -f(list_object) # E: Argument 1 to "f" has incompatible type "List[object]"; expected "List[int]" +f(list_object) # E: Argument 1 to "f" has incompatible type "List[object]"; expected one of "List[int]", "List[str]" based on available overloads [builtins fixtures/list.pyi] [case testOverlappingOverloadSignatures] @@ -1490,3 +1490,97 @@ class Child4(ParentWithDynamicImpl): [builtins fixtures/tuple.pyi] +[case testOverloadInferUnionReturnBasic] +from typing import overload, Union + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f1(x: A) -> B: ... +@overload +def f1(x: C) -> D: ... +def f1(x): ... + +arg1: Union[A, C] +reveal_type(f1(arg1)) # E: Revealed type is 'Union[__main__.B, __main__.D]' + +@overload +def f2(x: A) -> B: ... +@overload +def f2(x: C) -> B: ... +def f2(x): ... + +reveal_type(f2(arg1)) # E: Revealed type is '__main__.B' + +[builtins fixtures/tuple.pyi] + +[case testOverloadInferUnionReturnMultipleArguments] +from typing import overload, Union + +class A: ... +class B: ... +class C: ... +class D: ... + +@overload +def f1(x: A, y: C) -> B: ... +@overload +def f1(x: C, y: A) -> D: ... +def f1(x, y): ... + +# TODO: Instead of defaulting to picking the first overload, display a nicer error message +arg1: Union[A, C] +reveal_type(f1(arg1, arg1)) + +@overload +def f2(x: A, y: C) -> B: ... +@overload +def f2(x: C, y: C) -> D: ... +def f2(x, y): ... + +reveal_type(f2(arg1, arg1)) +reveal_type(f2(arg1, C())) + +[builtins fixtures/tuple.pyi] +[out] +main:16: error: Revealed type is '__main__.B' +main:16: error: Argument 1 to "f1" has incompatible type "Union[A, C]"; expected "A" +main:16: error: Argument 2 to "f1" has incompatible type "Union[A, C]"; expected "C" +main:24: error: Revealed type is 'Union[__main__.B, __main__.D]' +main:24: error: Argument 2 to "f2" has incompatible type "Union[A, C]"; expected "C" +main:25: error: Revealed type is 'Union[__main__.B, __main__.D]' + +[case testOverloadInferUnionReturnFunctionsWithKwargs] +from typing import overload, Union, Optional + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: A) -> A: ... +@overload +def f(x: A, y: Optional[B] = None) -> C: ... +@overload +def f(x: A, z: Optional[C] = None) -> B: ... +def f(x, y=None, z=None): ... + +reveal_type(f(A(), B())) +reveal_type(f(A(), C())) + +arg: Union[B, C] +reveal_type(f(A(), arg)) + +reveal_type(f(A())) + +[builtins fixtures/tuple.pyi] +[out] +main:15: error: Revealed type is '__main__.C' +main:16: error: Revealed type is '__main__.B' +main:19: error: Revealed type is '__main__.C' +main:19: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]" +main:21: error: Revealed type is '__main__.A' + diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 976720422002..19868461e692 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -1301,8 +1301,7 @@ def f(x): reveal_type(f(C1())) # E: Revealed type is 'builtins.int' reveal_type(f(C2())) # E: Revealed type is 'builtins.str' class D(C1, C2): pass # Compatible with both P1 and P2 -# FIXME: the below is not right, see #1322 -reveal_type(f(D())) # E: Revealed type is 'Any' +reveal_type(f(D())) # E: Revealed type is 'builtins.int' f(C()) # E: No overload variant of "f" matches argument types [__main__.C] [builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index ea4fa97f74ea..e9cb7c69550d 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -1249,7 +1249,7 @@ b: B c: C f(a) f(b) -f(c) # E: Argument 1 to "f" has incompatible type "C"; expected "A" +f(c) # E: Argument 1 to "f" has incompatible type "C"; expected one of "A", "B" based on available overloads [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] @@ -1268,8 +1268,8 @@ def f(x): pass a: A b: B -reveal_type(f(a)) # E: Revealed type is 'Any' -reveal_type(f(b)) # E: Revealed type is 'Any' +reveal_type(f(a)) # E: Revealed type is 'builtins.int' +reveal_type(f(b)) # E: Revealed type is 'builtins.str' [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] diff --git a/test-data/unit/lib-stub/attr.pyi b/test-data/unit/lib-stub/attr.pyi index d62a99a685eb..97bd9705a944 100644 --- a/test-data/unit/lib-stub/attr.pyi +++ b/test-data/unit/lib-stub/attr.pyi @@ -3,28 +3,16 @@ from typing import TypeVar, overload, Callable, Any, Type, Optional _T = TypeVar('_T') _C = TypeVar('_C', bound=type) -@overload def attr(default: Optional[_T] = ..., - validator: Optional[Any] = ..., - repr: bool = ..., - cmp: bool = ..., - hash: Optional[bool] = ..., - init: bool = ..., - convert: Optional[Callable[[Any], _T]] = ..., - metadata: Any = ..., - type: Optional[Type[_T]] = ..., - converter: Optional[Callable[[Any], _T]] = ...) -> _T: ... -@overload -def attr(default: None = ..., - validator: None = ..., + validator: Any = ..., repr: bool = ..., cmp: bool = ..., hash: Optional[bool] = ..., init: bool = ..., convert: Optional[Callable[[Any], _T]] = ..., metadata: Any = ..., - type: None = ..., - converter: None = ...) -> Any: ... + type: Type[_T] = ..., + converter: Optional[Callable[[Any], _T]] = ...) -> Any: ... @overload def attributes(maybe_cls: _C, From 5919eb5102f10d9b2da9557ca5351355111e397a Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Mon, 2 Apr 2018 16:42:15 -0700 Subject: [PATCH 2/3] Fixes checkexpr to use Python 3.4/3.5 compatible syntax --- mypy/checkexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 35497071a88b..e1fa15832f10 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1211,7 +1211,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call Returns None if it is not possible to combine the different callables together in a sound manner.""" - new_args: List[List[Type]] = [[] for _ in range(len(callables[0].arg_types))] + new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] expected_names = callables[0].arg_names expected_kinds = callables[0].arg_kinds From a5d181ba5135c9a263928a7438dd2ea1d675b8aa Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 11 Apr 2018 08:29:34 -0700 Subject: [PATCH 3/3] Correctly handle generics when doing overload union math --- mypy/checkexpr.py | 29 +++++-- test-data/unit/check-overloading.test | 104 ++++++++++++++++++++------ 2 files changed, 106 insertions(+), 27 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e1fa15832f10..4873b158c8da 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -659,7 +659,8 @@ def check_call(self, callee: Type, args: List[Expression], # and return type. targets = cast(List[CallableType], targets) - unioned_callable = self.union_overload_matches(targets) + unioned_callable = self.union_overload_matches( + targets, args, arg_kinds, arg_names, context) if unioned_callable is None: # If it was not possible to actually combine together the # callables in a sound way, we give up and return the original @@ -1204,14 +1205,18 @@ def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int], arg_types, arg_kinds, arg_names, m, context=context)] return out if len(out) >= 1 else match - def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]: + def union_overload_matches(self, callables: List[CallableType], + args: List[Expression], + arg_kinds: List[int], + arg_names: Optional[Sequence[Optional[str]]], + context: Context) -> Optional[CallableType]: """Accepts a list of overload signatures and attempts to combine them together into a new CallableType consisting of the union of all of the given arguments and return types. Returns None if it is not possible to combine the different callables together in a sound manner.""" - new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]] + new_returns = [] # type: List[Type] expected_names = callables[0].arg_names expected_kinds = callables[0].arg_kinds @@ -1222,13 +1227,25 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call # TODO: Enhance the union overload logic to handle a wider variety of signatures. return None + if target.is_generic(): + formal_to_actual = map_actuals_to_formals( + arg_kinds, arg_names, + target.arg_kinds, target.arg_names, + lambda i: self.accept(args[i])) + + target = freshen_function_type_vars(target) + target = self.infer_function_type_arguments_using_context(target, context) + target = self.infer_function_type_arguments( + target, args, arg_kinds, formal_to_actual, context) + for i, arg in enumerate(target.arg_types): new_args[i].append(arg) + new_returns.append(target.ret_type) union_count = 0 final_args = [] - for args in new_args: - new_type = UnionType.make_simplified_union(args) + for args_list in new_args: + new_type = UnionType.make_simplified_union(args_list) union_count += 1 if isinstance(new_type, UnionType) else 0 final_args.append(new_type) @@ -1256,7 +1273,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call return callables[0].copy_modified( arg_types=final_args, - ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]), + ret_type=UnionType.make_simplified_union(new_returns), implicit=True, from_overloads=True) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index b4d1e07347d2..3294eee7771f 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -1444,52 +1444,52 @@ class Child4(Parent): [case testOverloadWithIncompatibleMethodOverrideAndImplementation] from typing import overload, Union, Any -class StrSub: pass +class Sub: pass +class A: pass +class B: pass class ParentWithTypedImpl: @overload - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: str) -> str: ... - def f(self, arg: Union[int, str]) -> Union[int, str]: ... + def f(self, arg: B) -> B: ... + def f(self, arg: Union[A, B]) -> Union[A, B]: ... class Child1(ParentWithTypedImpl): @overload # E: Signature of "f" incompatible with supertype "ParentWithTypedImpl" - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: StrSub) -> str: ... - def f(self, arg: Union[int, StrSub]) -> Union[int, str]: ... + def f(self, arg: Sub) -> B: ... + def f(self, arg: Union[A, Sub]) -> Union[A, B]: ... class Child2(ParentWithTypedImpl): @overload # E: Signature of "f" incompatible with supertype "ParentWithTypedImpl" - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: StrSub) -> str: ... + def f(self, arg: Sub) -> B: ... def f(self, arg: Any) -> Any: ... class ParentWithDynamicImpl: @overload - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: str) -> str: ... + def f(self, arg: B) -> B: ... def f(self, arg: Any) -> Any: ... class Child3(ParentWithDynamicImpl): @overload # E: Signature of "f" incompatible with supertype "ParentWithDynamicImpl" - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: StrSub) -> str: ... - def f(self, arg: Union[int, StrSub]) -> Union[int, str]: ... + def f(self, arg: Sub) -> B: ... + def f(self, arg: Union[A, Sub]) -> Union[A, B]: ... class Child4(ParentWithDynamicImpl): @overload # E: Signature of "f" incompatible with supertype "ParentWithDynamicImpl" - def f(self, arg: int) -> int: ... + def f(self, arg: A) -> A: ... @overload - def f(self, arg: StrSub) -> str: ... + def f(self, arg: Sub) -> B: ... def f(self, arg: Any) -> Any: ... -[builtins fixtures/tuple.pyi] - [case testOverloadInferUnionReturnBasic] from typing import overload, Union @@ -1515,8 +1515,6 @@ def f2(x): ... reveal_type(f2(arg1)) # E: Revealed type is '__main__.B' -[builtins fixtures/tuple.pyi] - [case testOverloadInferUnionReturnMultipleArguments] from typing import overload, Union @@ -1544,7 +1542,6 @@ def f2(x, y): ... reveal_type(f2(arg1, arg1)) reveal_type(f2(arg1, C())) -[builtins fixtures/tuple.pyi] [out] main:16: error: Revealed type is '__main__.B' main:16: error: Argument 1 to "f1" has incompatible type "Union[A, C]"; expected "A" @@ -1553,6 +1550,24 @@ main:24: error: Revealed type is 'Union[__main__.B, __main__.D]' main:24: error: Argument 2 to "f2" has incompatible type "Union[A, C]"; expected "C" main:25: error: Revealed type is 'Union[__main__.B, __main__.D]' +[case testOverloadInferUnionSkipIfParameterNamesAreDifferent] +from typing import overload, Union + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: A) -> B: ... +@overload +def f(y: B) -> C: ... +def f(x): ... + +x: Union[A, B] +reveal_type(f(A())) # E: Revealed type is '__main__.B' +reveal_type(f(B())) # E: Revealed type is '__main__.C' +f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A" + [case testOverloadInferUnionReturnFunctionsWithKwargs] from typing import overload, Union, Optional @@ -1584,3 +1599,50 @@ main:19: error: Revealed type is '__main__.C' main:19: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]" main:21: error: Revealed type is '__main__.A' +[case testOverloadingInferUnionReturnWithTypevarWithValueRestriction] +from typing import overload, Union, TypeVar, Generic + +class A: pass +class B: pass +class C: pass + +T = TypeVar('T', B, C) + +class Wrapper(Generic[T]): + @overload + def f(self, x: T) -> B: ... + + @overload + def f(self, x: A) -> C: ... + + def f(self, x): ... + +obj: Wrapper[B] = Wrapper() +x: Union[A, B] + +reveal_type(obj.f(A())) # E: Revealed type is '__main__.C' +reveal_type(obj.f(B())) # E: Revealed type is '__main__.B' +reveal_type(obj.f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]' + +[case testOverloadingInferUnionReturnWithTypevarReturn] +from typing import overload, Union, TypeVar, Generic + +T = TypeVar('T') + +class Wrapper1(Generic[T]): pass +class Wrapper2(Generic[T]): pass +class A: pass +class B: pass + +@overload +def f(x: Wrapper1[T]) -> T: ... +@overload +def f(x: Wrapper2[T]) -> T: ... +def f(x): ... + +obj1: Union[Wrapper1[A], Wrapper2[A]] +reveal_type(f(obj1)) # E: Revealed type is '__main__.A' + +obj2: Union[Wrapper1[A], Wrapper2[B]] +reveal_type(f(obj2)) # E: Revealed type is 'Union[__main__.A, __main__.B]' +