From da33c20467708623bff691bafffd3fd214a1144b Mon Sep 17 00:00:00 2001 From: sobolevn Date: Wed, 14 Aug 2024 09:41:51 +0300 Subject: [PATCH 1/2] Infer correct types with overloads of `Type[Guard | Is]` --- mypy/checker.py | 22 ++++- mypy/checkexpr.py | 83 ++++++++++++++++--- test-data/unit/check-typeguard.test | 56 +++++++++++++ test-data/unit/check-typeis.test | 119 ++++++++++++++++++++++++++++ 4 files changed, 266 insertions(+), 14 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index db65660bbfbd..9cb14ff00e61 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5879,15 +5879,29 @@ def find_isinstance_check_helper( # considered "always right" (i.e. even if the types are not overlapping). # Also note that a care must be taken to unwrap this back at read places # where we use this to narrow down declared type. - if node.callee.type_guard is not None: - return {expr: TypeGuardedType(node.callee.type_guard)}, {} + with self.msg.filter_errors(), self.local_type_map(): + _, real_func = self.expr_checker.check_call( + get_proper_type(self.lookup_type(node.callee)), + node.args, + node.arg_kinds, + node, + node.arg_names, + ) + real_func = get_proper_type(real_func) + if not isinstance(real_func, CallableType) or not ( + real_func.type_guard or real_func.type_is + ): + return {}, {} + + if real_func.type_guard is not None: + return {expr: TypeGuardedType(real_func.type_guard)}, {} else: - assert node.callee.type_is is not None + assert real_func.type_is is not None return conditional_types_to_typemaps( expr, *self.conditional_types_with_intersection( self.lookup_type(expr), - [TypeRange(node.callee.type_is, is_upper_bound=False)], + [TypeRange(real_func.type_is, is_upper_bound=False)], expr, ), ) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9dee743ad406..1e5b8b218bfe 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2876,16 +2876,37 @@ def infer_overload_return_type( elif all_same_types([erase_type(typ) for typ in return_types]): self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) - else: - return self.check_call( - callee=AnyType(TypeOfAny.special_form), - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + return self.check_call( + callee=AnyType(TypeOfAny.special_form), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) + elif not all_same_type_narrowers(matches): + # This is an example of how overloads can be: + # + # @overload + # def is_int(obj: float) -> TypeGuard[float]: ... + # @overload + # def is_int(obj: int) -> TypeGuard[int]: ... + # + # x: Any + # if is_int(x): + # reveal_type(x) # N: int | float + # + # So, we need to check that special case. + return self.check_call( + callee=self.combine_function_signatures(cast("list[ProperType]", matches)), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) else: # Success! No ambiguity; return the first match. self.chk.store_types(type_maps[0]) @@ -3100,6 +3121,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call new_args: list[list[Type]] = [[] for _ in range(len(callables[0].arg_types))] new_kinds = list(callables[0].arg_kinds) new_returns: list[Type] = [] + new_type_guards: list[Type] = [] + new_type_narrowers: list[Type] = [] too_complex = False for target in callables: @@ -3126,8 +3149,25 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call for i, arg in enumerate(target.arg_types): new_args[i].append(arg) new_returns.append(target.ret_type) + if target.type_guard: + new_type_guards.append(target.type_guard) + if target.type_is: + new_type_narrowers.append(target.type_is) + + if new_type_guards and new_type_narrowers: + # They cannot be definined at the same time, + # declaring this function as too complex! + too_complex = True + union_type_guard = None + union_type_is = None + else: + union_type_guard = make_simplified_union(new_type_guards) if new_type_guards else None + union_type_is = ( + make_simplified_union(new_type_narrowers) if new_type_narrowers else None + ) union_return = make_simplified_union(new_returns) + if too_complex: any = AnyType(TypeOfAny.special_form) return callables[0].copy_modified( @@ -3137,6 +3177,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call ret_type=union_return, variables=variables, implicit=True, + type_guard=union_type_guard, + type_is=union_type_is, ) final_args = [] @@ -3150,6 +3192,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call ret_type=union_return, variables=variables, implicit=True, + type_guard=union_type_guard, + type_is=union_type_is, ) def erased_signature_similarity( @@ -6464,6 +6508,25 @@ def all_same_types(types: list[Type]) -> bool: return all(is_same_type(t, types[0]) for t in types[1:]) +def all_same_type_narrowers(types: list[CallableType]) -> bool: + if not types: + return True + + type_guards: list[Type] = [] + type_narrowers: list[Type] = [] + + for typ in types: + if typ.type_guard: + type_guards.append(typ.type_guard) + if typ.type_is: + type_narrowers.append(typ.type_is) + if type_guards and type_narrowers: + # Some overloads declare `TypeGuard` and some declare `TypeIs`, + # we cannot handle this in a union. + return False + return all_same_types(type_guards) and all_same_types(type_narrowers) + + def merge_typevars_in_callables_by_name( callables: Sequence[CallableType], ) -> tuple[list[CallableType], list[TypeVarType]]: diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 27b88553fb43..9ae9938c7485 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -721,3 +721,59 @@ x: object assert a(x=x) reveal_type(x) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] + +[case testTypeGuardInOverloads] +from typing import Any, overload, Union +from typing_extensions import TypeGuard + +@overload +def func1(x: str) -> TypeGuard[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Any): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]" + else: + reveal_type(val) # N: Revealed type is "Any" + +def func3(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + +def func4(val: int): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.int" + else: + reveal_type(val) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsSameReturn] +from typing import Any, overload, Union +from typing_extensions import TypeGuard + +@overload +def func1(x: str) -> TypeGuard[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[str]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.str" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-typeis.test b/test-data/unit/check-typeis.test index 6b96845504ab..110926b511f9 100644 --- a/test-data/unit/check-typeis.test +++ b/test-data/unit/check-typeis.test @@ -808,3 +808,122 @@ accept_typeguard(typeis) # E: Argument 1 to "accept_typeguard" has incompatible accept_typeguard(typeguard) [builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloads] +from typing import Any, overload, Union +from typing_extensions import TypeIs + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeIs[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Any): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]" + else: + reveal_type(val) # N: Revealed type is "Any" + +def func3(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) + +def func4(val: int): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.int" + else: + reveal_type(val) +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsSameReturn] +from typing import Any, overload, Union +from typing_extensions import TypeIs + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeIs[str]: # type: ignore + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "builtins.str" + else: + reveal_type(val) # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsUnionizeError] +from typing import Any, overload, Union +from typing_extensions import TypeIs, TypeGuard + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testTypeIsInOverloadsUnionizeError2] +from typing import Any, overload, Union +from typing_extensions import TypeIs, TypeGuard + +@overload +def func1(x: int) -> TypeGuard[int]: + ... + +@overload +def func1(x: str) -> TypeIs[str]: + ... + +def func1(x: Any) -> Any: + return True + +def func2(val: Union[int, str]): + if func1(val): + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" + else: + reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testTypeIsLikeIsDataclass] +from typing import Any, overload, Union, Type +from typing_extensions import TypeIs + +class DataclassInstance: ... + +@overload +def is_dataclass(obj: type) -> TypeIs[Type[DataclassInstance]]: ... +@overload +def is_dataclass(obj: object) -> TypeIs[Union[DataclassInstance, Type[DataclassInstance]]]: ... + +def is_dataclass(obj: Union[type, object]) -> bool: + return False + +def func(arg: Any) -> None: + if is_dataclass(arg): + reveal_type(arg) # N: Revealed type is "Union[Type[__main__.DataclassInstance], __main__.DataclassInstance]" +[builtins fixtures/tuple.pyi] From 50377fab1b8dd0ca5805587d5054474bcfac0101 Mon Sep 17 00:00:00 2001 From: sobolevn Date: Fri, 17 Jan 2025 18:47:54 +0300 Subject: [PATCH 2/2] Address review --- mypy/checker.py | 2 ++ mypy/checkexpr.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9e7de85d956f..5829b31447fe 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6037,6 +6037,8 @@ def find_isinstance_check_helper( # Also note that a care must be taken to unwrap this back at read places # where we use this to narrow down declared type. with self.msg.filter_errors(), self.local_type_map(): + # `node.callee` can be an `overload`ed function, + # we need to resolve the real `overload` case. _, real_func = self.expr_checker.check_call( get_proper_type(self.lookup_type(node.callee)), node.args, diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d280cc7093b3..a10dc00bb1de 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -6565,7 +6565,7 @@ def all_same_types(types: list[Type]) -> bool: def all_same_type_narrowers(types: list[CallableType]) -> bool: - if not types: + if len(types) <= 1: return True type_guards: list[Type] = []