Skip to content

Commit 28c67cb

Browse files
authored
More helpful type guards (#14238)
Fixes #13199 Refs #14425
1 parent 7c14eba commit 28c67cb

File tree

4 files changed

+142
-10
lines changed

4 files changed

+142
-10
lines changed

mypy/checker.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5350,10 +5350,26 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
53505350
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
53515351
elif isinstance(node.callee, RefExpr):
53525352
if node.callee.type_guard is not None:
5353-
# TODO: Follow keyword args or *args, **kwargs
5353+
# TODO: Follow *args, **kwargs
53545354
if node.arg_kinds[0] != nodes.ARG_POS:
5355-
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
5356-
return {}, {}
5355+
# the first argument might be used as a kwarg
5356+
called_type = get_proper_type(self.lookup_type(node.callee))
5357+
assert isinstance(called_type, (CallableType, Overloaded))
5358+
5359+
# *assuming* the overloaded function is correct, there's a couple cases:
5360+
# 1) The first argument has different names, but is pos-only. We don't
5361+
# care about this case, the argument must be passed positionally.
5362+
# 2) The first argument allows keyword reference, therefore must be the
5363+
# same between overloads.
5364+
name = called_type.items[0].arg_names[0]
5365+
5366+
if name in node.arg_names:
5367+
idx = node.arg_names.index(name)
5368+
# we want the idx-th variable to be narrowed
5369+
expr = collapse_walrus(node.args[idx])
5370+
else:
5371+
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
5372+
return {}, {}
53575373
if literal(expr) == LITERAL_TYPE:
53585374
# Note: we wrap the target type, so that we can special case later.
53595375
# Namely, for isinstance() we use a normal meet, while TypeGuard is

mypy/semanal.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,20 @@ def analyze_func_def(self, defn: FuncDef) -> None:
864864
return
865865
assert isinstance(result, ProperType)
866866
if isinstance(result, CallableType):
867+
# type guards need to have a positional argument, to spec
868+
if (
869+
result.type_guard
870+
and ARG_POS not in result.arg_kinds[self.is_class_scope() :]
871+
and not defn.is_static
872+
):
873+
self.fail(
874+
"TypeGuard functions must have a positional argument",
875+
result,
876+
code=codes.VALID_TYPE,
877+
)
878+
# in this case, we just kind of just ... remove the type guard.
879+
result = result.copy_modified(type_guard=None)
880+
867881
result = self.remove_unpack_kwargs(defn, result)
868882
if has_self_type and self.type is not None:
869883
info = self.type

test-data/unit/check-python38.test

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,34 @@ class C(Generic[T]):
735735
main:10: note: Revealed type is "builtins.int"
736736
main:10: note: Revealed type is "builtins.str"
737737

738+
[case testTypeGuardWithPositionalOnlyArg]
739+
# flags: --python-version 3.8
740+
from typing_extensions import TypeGuard
741+
742+
def typeguard(x: object, /) -> TypeGuard[int]:
743+
...
744+
745+
n: object
746+
if typeguard(n):
747+
reveal_type(n)
748+
[builtins fixtures/tuple.pyi]
749+
[out]
750+
main:9: note: Revealed type is "builtins.int"
751+
752+
[case testTypeGuardKeywordFollowingWalrus]
753+
# flags: --python-version 3.8
754+
from typing import cast
755+
from typing_extensions import TypeGuard
756+
757+
def typeguard(x: object) -> TypeGuard[int]:
758+
...
759+
760+
if typeguard(x=(n := cast(object, "hi"))):
761+
reveal_type(n)
762+
[builtins fixtures/tuple.pyi]
763+
[out]
764+
main:9: note: Revealed type is "builtins.int"
765+
738766
[case testNoCrashOnAssignmentExprClass]
739767
class C:
740768
[(j := i) for i in [1, 2, 3]] # E: Assignment expression within a comprehension cannot be used in a class body

test-data/unit/check-typeguard.test

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ reveal_type(foo) # N: Revealed type is "def (a: builtins.object) -> TypeGuard[b
3737
[case testTypeGuardCallArgsNone]
3838
from typing_extensions import TypeGuard
3939
class Point: pass
40-
# TODO: error on the 'def' line (insufficient args for type guard)
41-
def is_point() -> TypeGuard[Point]: pass
40+
41+
def is_point() -> TypeGuard[Point]: pass # E: TypeGuard functions must have a positional argument
4242
def main(a: object) -> None:
4343
if is_point():
4444
reveal_type(a) # N: Revealed type is "builtins.object"
@@ -227,13 +227,13 @@ def main(a: object) -> None:
227227
from typing_extensions import TypeGuard
228228
def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass
229229
def main1(a: object) -> None:
230-
# This is debatable -- should we support these cases?
230+
if is_float(a=a, b=1):
231+
reveal_type(a) # N: Revealed type is "builtins.float"
231232

232-
if is_float(a=a, b=1): # E: Type guard requires positional argument
233-
reveal_type(a) # N: Revealed type is "builtins.object"
233+
if is_float(b=1, a=a):
234+
reveal_type(a) # N: Revealed type is "builtins.float"
234235

235-
if is_float(b=1, a=a): # E: Type guard requires positional argument
236-
reveal_type(a) # N: Revealed type is "builtins.object"
236+
# This is debatable -- should we support these cases?
237237

238238
ta = (a,)
239239
if is_float(*ta): # E: Type guard requires positional argument
@@ -597,3 +597,77 @@ def func(names: Tuple[str, ...]):
597597
if is_two_element_tuple(names):
598598
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
599599
[builtins fixtures/tuple.pyi]
600+
601+
[case testTypeGuardErroneousDefinitionFails]
602+
from typing_extensions import TypeGuard
603+
604+
class Z:
605+
def typeguard(self, *, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
606+
...
607+
608+
def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
609+
...
610+
[builtins fixtures/tuple.pyi]
611+
612+
[case testTypeGuardWithKeywordArg]
613+
from typing_extensions import TypeGuard
614+
615+
class Z:
616+
def typeguard(self, x: object) -> TypeGuard[int]:
617+
...
618+
619+
def typeguard(x: object) -> TypeGuard[int]:
620+
...
621+
622+
n: object
623+
if typeguard(x=n):
624+
reveal_type(n) # N: Revealed type is "builtins.int"
625+
626+
if Z().typeguard(x=n):
627+
reveal_type(n) # N: Revealed type is "builtins.int"
628+
[builtins fixtures/tuple.pyi]
629+
630+
[case testStaticMethodTypeGuard]
631+
from typing_extensions import TypeGuard
632+
633+
class Y:
634+
@staticmethod
635+
def typeguard(h: object) -> TypeGuard[int]:
636+
...
637+
638+
x: object
639+
if Y().typeguard(x):
640+
reveal_type(x) # N: Revealed type is "builtins.int"
641+
if Y.typeguard(x):
642+
reveal_type(x) # N: Revealed type is "builtins.int"
643+
[builtins fixtures/tuple.pyi]
644+
[builtins fixtures/classmethod.pyi]
645+
646+
[case testTypeGuardKwargFollowingThroughOverloaded]
647+
from typing import overload, Union
648+
from typing_extensions import TypeGuard
649+
650+
@overload
651+
def typeguard(x: object, y: str) -> TypeGuard[str]:
652+
...
653+
654+
@overload
655+
def typeguard(x: object, y: int) -> TypeGuard[int]:
656+
...
657+
658+
def typeguard(x: object, y: Union[int, str]) -> Union[TypeGuard[int], TypeGuard[str]]:
659+
...
660+
661+
x: object
662+
if typeguard(x=x, y=42):
663+
reveal_type(x) # N: Revealed type is "builtins.int"
664+
665+
if typeguard(y=42, x=x):
666+
reveal_type(x) # N: Revealed type is "builtins.int"
667+
668+
if typeguard(x=x, y="42"):
669+
reveal_type(x) # N: Revealed type is "builtins.str"
670+
671+
if typeguard(y="42", x=x):
672+
reveal_type(x) # N: Revealed type is "builtins.str"
673+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)