Skip to content

Commit fb11c98

Browse files
authored
Fixes generic inference in functions with TypeGuard (#11797)
Fixes #11780, fixes #11428
1 parent 49d5cc9 commit fb11c98

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

mypy/applytype.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,18 @@ def apply_generic_arguments(
103103
# Apply arguments to argument types.
104104
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
105105

106+
# Apply arguments to TypeGuard if any.
107+
if callable.type_guard is not None:
108+
type_guard = expand_type(callable.type_guard, id_to_type)
109+
else:
110+
type_guard = None
111+
106112
# The callable may retain some type vars if only some were applied.
107113
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]
108114

109115
return callable.copy_modified(
110116
arg_types=arg_types,
111117
ret_type=expand_type(callable.ret_type, id_to_type),
112118
variables=remaining_tvars,
119+
type_guard=type_guard,
113120
)

mypy/checkexpr.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,6 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
344344
ret_type=self.object_type(),
345345
fallback=self.named_type('builtins.function'))
346346
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
347-
if (isinstance(e.callee, RefExpr)
348-
and isinstance(callee_type, CallableType)
349-
and callee_type.type_guard is not None):
350-
# Cache it for find_isinstance_check()
351-
e.callee.type_guard = callee_type.type_guard
352347
if (self.chk.options.disallow_untyped_calls and
353348
self.chk.in_checked_function() and
354349
isinstance(callee_type, CallableType)
@@ -886,10 +881,19 @@ def check_call_expr_with_callee_type(self,
886881
# Unions are special-cased to allow plugins to act on each item in the union.
887882
elif member is not None and isinstance(object_type, UnionType):
888883
return self.check_union_call_expr(e, object_type, member)
889-
return self.check_call(callee_type, e.args, e.arg_kinds, e,
890-
e.arg_names, callable_node=e.callee,
891-
callable_name=callable_name,
892-
object_type=object_type)[0]
884+
ret_type, callee_type = self.check_call(
885+
callee_type, e.args, e.arg_kinds, e,
886+
e.arg_names, callable_node=e.callee,
887+
callable_name=callable_name,
888+
object_type=object_type,
889+
)
890+
proper_callee = get_proper_type(callee_type)
891+
if (isinstance(e.callee, RefExpr)
892+
and isinstance(proper_callee, CallableType)
893+
and proper_callee.type_guard is not None):
894+
# Cache it for find_isinstance_check()
895+
e.callee.type_guard = proper_callee.type_guard
896+
return ret_type
893897

894898
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
895899
""""Type check calling a member expression where the base type is a union."""

test-data/unit/check-typeguard.test

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,3 +547,53 @@ accepts_typeguard(with_typeguard_a) # E: Argument 1 to "accepts_typeguard" has
547547
accepts_typeguard(with_typeguard_b)
548548
accepts_typeguard(with_typeguard_c)
549549
[builtins fixtures/tuple.pyi]
550+
551+
[case testTypeGuardWithIdentityGeneric]
552+
from typing import TypeVar
553+
from typing_extensions import TypeGuard
554+
555+
_T = TypeVar("_T")
556+
557+
def identity(val: _T) -> TypeGuard[_T]:
558+
pass
559+
560+
def func1(name: _T):
561+
reveal_type(name) # N: Revealed type is "_T`-1"
562+
if identity(name):
563+
reveal_type(name) # N: Revealed type is "_T`-1"
564+
565+
def func2(name: str):
566+
reveal_type(name) # N: Revealed type is "builtins.str"
567+
if identity(name):
568+
reveal_type(name) # N: Revealed type is "builtins.str"
569+
[builtins fixtures/tuple.pyi]
570+
571+
[case testTypeGuardWithGenericInstance]
572+
from typing import TypeVar, List
573+
from typing_extensions import TypeGuard
574+
575+
_T = TypeVar("_T")
576+
577+
def is_list_of_str(val: _T) -> TypeGuard[List[_T]]:
578+
pass
579+
580+
def func(name: str):
581+
reveal_type(name) # N: Revealed type is "builtins.str"
582+
if is_list_of_str(name):
583+
reveal_type(name) # N: Revealed type is "builtins.list[builtins.str]"
584+
[builtins fixtures/tuple.pyi]
585+
586+
[case testTypeGuardWithTupleGeneric]
587+
from typing import TypeVar, Tuple
588+
from typing_extensions import TypeGuard
589+
590+
_T = TypeVar("_T")
591+
592+
def is_two_element_tuple(val: Tuple[_T, ...]) -> TypeGuard[Tuple[_T, _T]]:
593+
pass
594+
595+
def func(names: Tuple[str, ...]):
596+
reveal_type(names) # N: Revealed type is "builtins.tuple[builtins.str, ...]"
597+
if is_two_element_tuple(names):
598+
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
599+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)