Skip to content

Commit fffbe88

Browse files
authored
Support TypeGuard (PEP 647) (#9865)
PEP 647 is still in draft mode, but it is likely to be accepted, and this helps solve some real issues.
1 parent 734e4ad commit fffbe88

13 files changed

+408
-9
lines changed

mypy/checker.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
3636
is_named_instance, union_items, TypeQuery, LiteralType,
3737
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
38-
get_proper_types, is_literal_type, TypeAliasType)
38+
get_proper_types, is_literal_type, TypeAliasType, TypeGuardType)
3939
from mypy.sametypes import is_same_type
4040
from mypy.messages import (
4141
MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq,
@@ -3957,6 +3957,7 @@ def find_isinstance_check(self, node: Expression
39573957
) -> Tuple[TypeMap, TypeMap]:
39583958
"""Find any isinstance checks (within a chain of ands). Includes
39593959
implicit and explicit checks for None and calls to callable.
3960+
Also includes TypeGuard functions.
39603961
39613962
Return value is a map of variables to their types if the condition
39623963
is true and a map of variables to their types if the condition is false.
@@ -4001,6 +4002,14 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
40014002
if literal(expr) == LITERAL_TYPE:
40024003
vartype = type_map[expr]
40034004
return self.conditional_callable_type_map(expr, vartype)
4005+
elif isinstance(node.callee, RefExpr):
4006+
if node.callee.type_guard is not None:
4007+
# TODO: Follow keyword args or *args, **kwargs
4008+
if node.arg_kinds[0] != nodes.ARG_POS:
4009+
self.fail("Type guard requires positional argument", node)
4010+
return {}, {}
4011+
if literal(expr) == LITERAL_TYPE:
4012+
return {expr: TypeGuardType(node.callee.type_guard)}, {}
40044013
elif isinstance(node, ComparisonExpr):
40054014
# Step 1: Obtain the types of each operand and whether or not we can
40064015
# narrow their types. (For example, we shouldn't try narrowing the

mypy/checkexpr.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
make_optional_type,
1515
)
1616
from mypy.types import (
17-
Type, AnyType, CallableType, Overloaded, NoneType, TypeVarDef,
17+
Type, AnyType, CallableType, Overloaded, NoneType, TypeGuardType, TypeVarDef,
1818
TupleType, TypedDictType, Instance, TypeVarType, ErasedType, UnionType,
1919
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue,
2020
is_named_instance, FunctionLike,
@@ -317,6 +317,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
317317
ret_type=self.object_type(),
318318
fallback=self.named_type('builtins.function'))
319319
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
320+
if (isinstance(e.callee, RefExpr)
321+
and isinstance(callee_type, CallableType)
322+
and callee_type.type_guard is not None):
323+
# Cache it for find_isinstance_check()
324+
e.callee.type_guard = callee_type.type_guard
320325
if (self.chk.options.disallow_untyped_calls and
321326
self.chk.in_checked_function() and
322327
isinstance(callee_type, CallableType)
@@ -4163,6 +4168,10 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type,
41634168
"""
41644169
if literal(expr) >= LITERAL_TYPE:
41654170
restriction = self.chk.binder.get(expr)
4171+
# Ignore the error about using get_proper_type().
4172+
if isinstance(restriction, TypeGuardType): # type: ignore[misc]
4173+
# A type guard forces the new type even if it doesn't overlap the old.
4174+
return restriction.type_guard
41664175
# If the current node is deferred, some variables may get Any types that they
41674176
# otherwise wouldn't have. We don't want to narrow down these since it may
41684177
# produce invalid inferred Optional[Any] types, at least.

mypy/constraints.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,12 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
457457
for t, a in zip(template.arg_types, cactual.arg_types):
458458
# Negate direction due to function argument type contravariance.
459459
res.extend(infer_constraints(t, a, neg_op(self.direction)))
460-
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
460+
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
461+
if template.type_guard is not None:
462+
template_ret_type = template.type_guard
463+
if cactual.type_guard is not None:
464+
cactual_ret_type = cactual.type_guard
465+
res.extend(infer_constraints(template_ret_type, cactual_ret_type,
461466
self.direction))
462467
return res
463468
elif isinstance(self.actual, AnyType):

mypy/expandtype.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:
9797

9898
def visit_callable_type(self, t: CallableType) -> Type:
9999
return t.copy_modified(arg_types=self.expand_types(t.arg_types),
100-
ret_type=t.ret_type.accept(self))
100+
ret_type=t.ret_type.accept(self),
101+
type_guard=(t.type_guard.accept(self)
102+
if t.type_guard is not None else None))
101103

102104
def visit_overloaded(self, t: Overloaded) -> Type:
103105
items = [] # type: List[CallableType]

mypy/fixup.py

+2
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
192192
for arg in ct.bound_args:
193193
if arg:
194194
arg.accept(self)
195+
if ct.type_guard is not None:
196+
ct.type_guard.accept(self)
195197

196198
def visit_overloaded(self, t: Overloaded) -> None:
197199
for ct in t.items():

mypy/nodes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1448,7 +1448,8 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
14481448
class RefExpr(Expression):
14491449
"""Abstract base class for name-like constructs"""
14501450

1451-
__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue')
1451+
__slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue',
1452+
'type_guard')
14521453

14531454
def __init__(self) -> None:
14541455
super().__init__()
@@ -1467,6 +1468,8 @@ def __init__(self) -> None:
14671468
self.is_inferred_def = False
14681469
# Is this expression appears as an rvalue of a valid type alias definition?
14691470
self.is_alias_rvalue = False
1471+
# Cache type guard from callable_type.type_guard
1472+
self.type_guard = None # type: Optional[mypy.types.Type]
14701473

14711474

14721475
class NameExpr(RefExpr):

mypy/test/testcheck.py

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
'check-annotated.test',
9393
'check-parameter-specification.test',
9494
'check-generic-alias.test',
95+
'check-typeguard.test',
9596
]
9697

9798
# Tests that use Python 3.8-only AST features (like expression-scoped ignores):

mypy/typeanal.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt
345345
" and at least one annotation", t)
346346
return AnyType(TypeOfAny.from_error)
347347
return self.anal_type(t.args[0])
348+
elif self.anal_type_guard_arg(t, fullname) is not None:
349+
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
350+
return self.named_type('builtins.bool')
348351
return None
349352

350353
def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType:
@@ -524,15 +527,34 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
524527
variables = t.variables
525528
else:
526529
variables = self.bind_function_type_variables(t, t)
530+
special = self.anal_type_guard(t.ret_type)
527531
ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested),
528532
ret_type=self.anal_type(t.ret_type, nested=nested),
529533
# If the fallback isn't filled in yet,
530534
# its type will be the falsey FakeInfo
531535
fallback=(t.fallback if t.fallback.type
532536
else self.named_type('builtins.function')),
533-
variables=self.anal_var_defs(variables))
537+
variables=self.anal_var_defs(variables),
538+
type_guard=special,
539+
)
534540
return ret
535541

542+
def anal_type_guard(self, t: Type) -> Optional[Type]:
543+
if isinstance(t, UnboundType):
544+
sym = self.lookup_qualified(t.name, t)
545+
if sym is not None and sym.node is not None:
546+
return self.anal_type_guard_arg(t, sym.node.fullname)
547+
# TODO: What if it's an Instance? Then use t.type.fullname?
548+
return None
549+
550+
def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]:
551+
if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'):
552+
if len(t.args) != 1:
553+
self.fail("TypeGuard must have exactly one type argument", t)
554+
return AnyType(TypeOfAny.from_error)
555+
return self.anal_type(t.args[0])
556+
return None
557+
536558
def visit_overloaded(self, t: Overloaded) -> Type:
537559
# Overloaded types are manually constructed in semanal.py by analyzing the
538560
# AST and combining together the Callable types this visitor converts.

mypy/types.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,16 @@ def copy_modified(self, *,
270270
self.line, self.column)
271271

272272

273+
class TypeGuardType(Type):
274+
"""Only used by find_instance_check() etc."""
275+
def __init__(self, type_guard: Type):
276+
super().__init__(line=type_guard.line, column=type_guard.column)
277+
self.type_guard = type_guard
278+
279+
def __repr__(self) -> str:
280+
return "TypeGuard({})".format(self.type_guard)
281+
282+
273283
class ProperType(Type):
274284
"""Not a type alias.
275285
@@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
10051015
# tools that consume mypy ASTs
10061016
'def_extras', # Information about original definition we want to serialize.
10071017
# This is used for more detailed error messages.
1018+
'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case).
10081019
)
10091020

10101021
def __init__(self,
@@ -1024,6 +1035,7 @@ def __init__(self,
10241035
from_type_type: bool = False,
10251036
bound_args: Sequence[Optional[Type]] = (),
10261037
def_extras: Optional[Dict[str, Any]] = None,
1038+
type_guard: Optional[Type] = None,
10271039
) -> None:
10281040
super().__init__(line, column)
10291041
assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -1058,6 +1070,7 @@ def __init__(self,
10581070
not definition.is_static else None}
10591071
else:
10601072
self.def_extras = {}
1073+
self.type_guard = type_guard
10611074

10621075
def copy_modified(self,
10631076
arg_types: Bogus[Sequence[Type]] = _dummy,
@@ -1075,7 +1088,9 @@ def copy_modified(self,
10751088
special_sig: Bogus[Optional[str]] = _dummy,
10761089
from_type_type: Bogus[bool] = _dummy,
10771090
bound_args: Bogus[List[Optional[Type]]] = _dummy,
1078-
def_extras: Bogus[Dict[str, Any]] = _dummy) -> 'CallableType':
1091+
def_extras: Bogus[Dict[str, Any]] = _dummy,
1092+
type_guard: Bogus[Optional[Type]] = _dummy,
1093+
) -> 'CallableType':
10791094
return CallableType(
10801095
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
10811096
arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds,
@@ -1094,6 +1109,7 @@ def copy_modified(self,
10941109
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
10951110
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
10961111
def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras),
1112+
type_guard=type_guard if type_guard is not _dummy else self.type_guard,
10971113
)
10981114

10991115
def var_arg(self) -> Optional[FormalArgument]:
@@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
12551271
def serialize(self) -> JsonDict:
12561272
# TODO: As an optimization, leave out everything related to
12571273
# generic functions for non-generic functions.
1274+
assert (self.type_guard is None
1275+
or isinstance(get_proper_type(self.type_guard), Instance)), str(self.type_guard)
12581276
return {'.class': 'CallableType',
12591277
'arg_types': [t.serialize() for t in self.arg_types],
12601278
'arg_kinds': self.arg_kinds,
@@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
12691287
'bound_args': [(None if t is None else t.serialize())
12701288
for t in self.bound_args],
12711289
'def_extras': dict(self.def_extras),
1290+
'type_guard': self.type_guard.serialize() if self.type_guard is not None else None,
12721291
}
12731292

12741293
@classmethod
@@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
12861305
implicit=data['implicit'],
12871306
bound_args=[(None if t is None else deserialize_type(t))
12881307
for t in data['bound_args']],
1289-
def_extras=data['def_extras']
1308+
def_extras=data['def_extras'],
1309+
type_guard=(deserialize_type(data['type_guard'])
1310+
if data['type_guard'] is not None else None),
12901311
)
12911312

12921313

@@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
20972118
s = '({})'.format(s)
20982119

20992120
if not isinstance(get_proper_type(t.ret_type), NoneType):
2100-
s += ' -> {}'.format(t.ret_type.accept(self))
2121+
if t.type_guard is not None:
2122+
s += ' -> TypeGuard[{}]'.format(t.type_guard.accept(self))
2123+
else:
2124+
s += ' -> {}'.format(t.ret_type.accept(self))
21012125

21022126
if t.variables:
21032127
vs = []

test-data/unit/check-python38.test

+9
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,12 @@ def func() -> None:
392392
class Foo:
393393
def __init__(self) -> None:
394394
self.x = 123
395+
396+
[case testWalrusTypeGuard]
397+
from typing_extensions import TypeGuard
398+
def is_float(a: object) -> TypeGuard[float]: pass
399+
def main(a: object) -> None:
400+
if is_float(x := a):
401+
reveal_type(x) # N: Revealed type is 'builtins.float'
402+
reveal_type(a) # N: Revealed type is 'builtins.object'
403+
[builtins fixtures/tuple.pyi]

test-data/unit/check-serialize.test

+15
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,21 @@ def f(x: int) -> int: pass
224224
tmp/a.py:2: note: Revealed type is 'builtins.str'
225225
tmp/a.py:3: error: Unexpected keyword argument "x" for "f"
226226

227+
[case testSerializeTypeGuardFunction]
228+
import a
229+
[file a.py]
230+
import b
231+
[file a.py.2]
232+
import b
233+
reveal_type(b.guard(''))
234+
reveal_type(b.guard)
235+
[file b.py]
236+
from typing_extensions import TypeGuard
237+
def guard(a: object) -> TypeGuard[str]: pass
238+
[builtins fixtures/tuple.pyi]
239+
[out2]
240+
tmp/a.py:2: note: Revealed type is 'builtins.bool'
241+
tmp/a.py:3: note: Revealed type is 'def (a: builtins.object) -> TypeGuard[builtins.str]'
227242
--
228243
-- Classes
229244
--

0 commit comments

Comments
 (0)