diff --git a/mypy/applytype.py b/mypy/applytype.py index 51a10c7084cf..b32b88fa3276 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -103,6 +103,12 @@ def apply_generic_arguments( # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] + # Apply arguments to TypeGuard if any. + if callable.type_guard is not None: + type_guard = expand_type(callable.type_guard, id_to_type) + else: + type_guard = None + # The callable may retain some type vars if only some were applied. remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] @@ -110,4 +116,5 @@ def apply_generic_arguments( arg_types=arg_types, ret_type=expand_type(callable.ret_type, id_to_type), variables=remaining_tvars, + type_guard=type_guard, ) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ed6fd73acfa5..9737c3585a76 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -344,11 +344,6 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> ret_type=self.object_type(), fallback=self.named_type('builtins.function')) callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True)) - if (isinstance(e.callee, RefExpr) - and isinstance(callee_type, CallableType) - and callee_type.type_guard is not None): - # Cache it for find_isinstance_check() - e.callee.type_guard = callee_type.type_guard if (self.chk.options.disallow_untyped_calls and self.chk.in_checked_function() and isinstance(callee_type, CallableType) @@ -886,10 +881,19 @@ def check_call_expr_with_callee_type(self, # Unions are special-cased to allow plugins to act on each item in the union. elif member is not None and isinstance(object_type, UnionType): return self.check_union_call_expr(e, object_type, member) - return self.check_call(callee_type, e.args, e.arg_kinds, e, - e.arg_names, callable_node=e.callee, - callable_name=callable_name, - object_type=object_type)[0] + ret_type, callee_type = self.check_call( + callee_type, e.args, e.arg_kinds, e, + e.arg_names, callable_node=e.callee, + callable_name=callable_name, + object_type=object_type, + ) + proper_callee = get_proper_type(callee_type) + if (isinstance(e.callee, RefExpr) + and isinstance(proper_callee, CallableType) + and proper_callee.type_guard is not None): + # Cache it for find_isinstance_check() + e.callee.type_guard = proper_callee.type_guard + return ret_type def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type: """"Type check calling a member expression where the base type is a union.""" diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index ecefce091405..64fc7ea695cb 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -547,3 +547,53 @@ accepts_typeguard(with_typeguard_a) # E: Argument 1 to "accepts_typeguard" has accepts_typeguard(with_typeguard_b) accepts_typeguard(with_typeguard_c) [builtins fixtures/tuple.pyi] + +[case testTypeGuardWithIdentityGeneric] +from typing import TypeVar +from typing_extensions import TypeGuard + +_T = TypeVar("_T") + +def identity(val: _T) -> TypeGuard[_T]: + pass + +def func1(name: _T): + reveal_type(name) # N: Revealed type is "_T`-1" + if identity(name): + reveal_type(name) # N: Revealed type is "_T`-1" + +def func2(name: str): + reveal_type(name) # N: Revealed type is "builtins.str" + if identity(name): + reveal_type(name) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] + +[case testTypeGuardWithGenericInstance] +from typing import TypeVar, List +from typing_extensions import TypeGuard + +_T = TypeVar("_T") + +def is_list_of_str(val: _T) -> TypeGuard[List[_T]]: + pass + +def func(name: str): + reveal_type(name) # N: Revealed type is "builtins.str" + if is_list_of_str(name): + reveal_type(name) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testTypeGuardWithTupleGeneric] +from typing import TypeVar, Tuple +from typing_extensions import TypeGuard + +_T = TypeVar("_T") + +def is_two_element_tuple(val: Tuple[_T, ...]) -> TypeGuard[Tuple[_T, _T]]: + pass + +def func(names: Tuple[str, ...]): + reveal_type(names) # N: Revealed type is "builtins.tuple[builtins.str, ...]" + if is_two_element_tuple(names): + reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]" +[builtins fixtures/tuple.pyi]