diff --git a/mypy/checker.py b/mypy/checker.py index 47b08b683e36..5829b31447fe 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6036,15 +6036,31 @@ def find_isinstance_check_helper( # considered "always right" (i.e. even if the types are not overlapping). # Also note that a care must be taken to unwrap this back at read places # where we use this to narrow down declared type. - if node.callee.type_guard is not None: - return {expr: TypeGuardedType(node.callee.type_guard)}, {} + with self.msg.filter_errors(), self.local_type_map(): + # `node.callee` can be an `overload`ed function, + # we need to resolve the real `overload` case. + _, real_func = self.expr_checker.check_call( + get_proper_type(self.lookup_type(node.callee)), + node.args, + node.arg_kinds, + node, + node.arg_names, + ) + real_func = get_proper_type(real_func) + if not isinstance(real_func, CallableType) or not ( + real_func.type_guard or real_func.type_is + ): + return {}, {} + + if real_func.type_guard is not None: + return {expr: TypeGuardedType(real_func.type_guard)}, {} else: - assert node.callee.type_is is not None + assert real_func.type_is is not None return conditional_types_to_typemaps( expr, *self.conditional_types_with_intersection( self.lookup_type(expr), - [TypeRange(node.callee.type_is, is_upper_bound=False)], + [TypeRange(real_func.type_is, is_upper_bound=False)], expr, ), ) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b6618109bb44..a10dc00bb1de 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2906,16 +2906,37 @@ def infer_overload_return_type( elif all_same_types([erase_type(typ) for typ in return_types]): self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) - else: - return self.check_call( - callee=AnyType(TypeOfAny.special_form), - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + return self.check_call( + callee=AnyType(TypeOfAny.special_form), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) + elif not all_same_type_narrowers(matches): + # This is an example of how overloads can be: + # + # @overload + # def is_int(obj: float) -> TypeGuard[float]: ... + # @overload + # def is_int(obj: int) -> TypeGuard[int]: ... + # + # x: Any + # if is_int(x): + # reveal_type(x) # N: int | float + # + # So, we need to check that special case. + return self.check_call( + callee=self.combine_function_signatures(cast("list[ProperType]", matches)), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) else: # Success! No ambiguity; return the first match. self.chk.store_types(type_maps[0]) @@ -3130,6 +3151,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call new_args: list[list[Type]] = [[] for _ in range(len(callables[0].arg_types))] new_kinds = list(callables[0].arg_kinds) new_returns: list[Type] = [] + new_type_guards: list[Type] = [] + new_type_narrowers: list[Type] = [] too_complex = False for target in callables: @@ -3156,8 +3179,25 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call for i, arg in enumerate(target.arg_types): new_args[i].append(arg) new_returns.append(target.ret_type) + if target.type_guard: + new_type_guards.append(target.type_guard) + if target.type_is: + new_type_narrowers.append(target.type_is) + + if new_type_guards and new_type_narrowers: + # They cannot be definined at the same time, + # declaring this function as too complex! + too_complex = True + union_type_guard = None + union_type_is = None + else: + union_type_guard = make_simplified_union(new_type_guards) if new_type_guards else None + union_type_is = ( + make_simplified_union(new_type_narrowers) if new_type_narrowers else None + ) union_return = make_simplified_union(new_returns) + if too_complex: any = AnyType(TypeOfAny.special_form) return callables[0].copy_modified( @@ -3167,6 +3207,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call ret_type=union_return, variables=variables, implicit=True, + type_guard=union_type_guard, + type_is=union_type_is, ) final_args = [] @@ -3180,6 +3222,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call ret_type=union_return, variables=variables, implicit=True, + type_guard=union_type_guard, + type_is=union_type_is, ) def erased_signature_similarity( @@ -6520,6 +6564,25 @@ def all_same_types(types: list[Type]) -> bool: return all(is_same_type(t, types[0]) for t in types[1:]) +def all_same_type_narrowers(types: list[CallableType]) -> bool: + if len(types) <= 1: + return True + + type_guards: list[Type] = [] + type_narrowers: list[Type] = [] + + for typ in types: + if typ.type_guard: + type_guards.append(typ.type_guard) + if typ.type_is: + type_narrowers.append(typ.type_is) + if type_guards and type_narrowers: + # Some overloads declare `TypeGuard` and some declare `TypeIs`, + # we cannot handle this in a union. + return False + return all_same_types(type_guards) and all_same_types(type_narrowers) + + def merge_typevars_in_callables_by_name( callables: Sequence[CallableType], ) -> tuple[list[CallableType], list[TypeVarType]]: diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index e7a8eac4f043..eff3ce068cc7 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -730,3 +730,59 @@ x: object assert a(x=x) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] + +[case testTypeGuardInOverloads] +from typing import Any, overload, Union +from typing_extensions import TypeGuard + +@overload +def func1(x: str) -> TypeGuard[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Any): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]" + else: + reveal_type(val) # N: Revealed type is "Any" + +def func3(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + +def func4(val: int): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.int" + else: + reveal_type(val) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsSameReturn] +from typing import Any, overload, Union +from typing_extensions import TypeGuard + +@overload +def func1(x: str) -> TypeGuard[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[str]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.str" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-typeis.test b/test-data/unit/check-typeis.test index 2372f990fda1..7d1754bf8340 100644 --- a/test-data/unit/check-typeis.test +++ b/test-data/unit/check-typeis.test @@ -817,3 +817,122 @@ accept_typeguard(typeis) # E: Argument 1 to "accept_typeguard" has incompatible accept_typeguard(typeguard) [builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloads] +from typing import Any, overload, Union +from typing_extensions import TypeIs + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeIs[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Any): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]" + else: + reveal_type(val) # N: Revealed type is "Any" + +def func3(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) + +def func4(val: int): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.int" + else: + reveal_type(val) +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsSameReturn] +from typing import Any, overload, Union +from typing_extensions import TypeIs + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeIs[str]: # type: ignore + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.str" + else: + reveal_type(val) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsUnionizeError] +from typing import Any, overload, Union +from typing_extensions import TypeIs, TypeGuard + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsUnionizeError2] +from typing import Any, overload, Union +from typing_extensions import TypeIs, TypeGuard + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testTypeIsLikeIsDataclass] +from typing import Any, overload, Union, Type +from typing_extensions import TypeIs + +class DataclassInstance: ... + +@overload +def is_dataclass(obj: type) -> TypeIs[Type[DataclassInstance]]: ... +@overload +def is_dataclass(obj: object) -> TypeIs[Union[DataclassInstance, Type[DataclassInstance]]]: ... + +def is_dataclass(obj: Union[type, object]) -> bool: + return False + +def func(arg: Any) -> None: + if is_dataclass(arg): + reveal_type(arg) # N: Revealed type is "Union[Type[__main__.DataclassInstance], __main__.DataclassInstance]" +[builtins fixtures/tuple.pyi]