@@ -270,6 +270,16 @@ def copy_modified(self, *,
270
270
self .line , self .column )
271
271
272
272
273
+ class TypeGuardType (Type ):
274
+ """Only used by find_instance_check() etc."""
275
+ def __init__ (self , type_guard : Type ):
276
+ super ().__init__ (line = type_guard .line , column = type_guard .column )
277
+ self .type_guard = type_guard
278
+
279
+ def __repr__ (self ) -> str :
280
+ return "TypeGuard({})" .format (self .type_guard )
281
+
282
+
273
283
class ProperType (Type ):
274
284
"""Not a type alias.
275
285
@@ -1005,6 +1015,7 @@ class CallableType(FunctionLike):
1005
1015
# tools that consume mypy ASTs
1006
1016
'def_extras' , # Information about original definition we want to serialize.
1007
1017
# This is used for more detailed error messages.
1018
+ 'type_guard' , # T, if -> TypeGuard[T] (ret_type is bool in this case).
1008
1019
)
1009
1020
1010
1021
def __init__ (self ,
@@ -1024,6 +1035,7 @@ def __init__(self,
1024
1035
from_type_type : bool = False ,
1025
1036
bound_args : Sequence [Optional [Type ]] = (),
1026
1037
def_extras : Optional [Dict [str , Any ]] = None ,
1038
+ type_guard : Optional [Type ] = None ,
1027
1039
) -> None :
1028
1040
super ().__init__ (line , column )
1029
1041
assert len (arg_types ) == len (arg_kinds ) == len (arg_names )
@@ -1058,6 +1070,7 @@ def __init__(self,
1058
1070
not definition .is_static else None }
1059
1071
else :
1060
1072
self .def_extras = {}
1073
+ self .type_guard = type_guard
1061
1074
1062
1075
def copy_modified (self ,
1063
1076
arg_types : Bogus [Sequence [Type ]] = _dummy ,
@@ -1075,7 +1088,9 @@ def copy_modified(self,
1075
1088
special_sig : Bogus [Optional [str ]] = _dummy ,
1076
1089
from_type_type : Bogus [bool ] = _dummy ,
1077
1090
bound_args : Bogus [List [Optional [Type ]]] = _dummy ,
1078
- def_extras : Bogus [Dict [str , Any ]] = _dummy ) -> 'CallableType' :
1091
+ def_extras : Bogus [Dict [str , Any ]] = _dummy ,
1092
+ type_guard : Bogus [Optional [Type ]] = _dummy ,
1093
+ ) -> 'CallableType' :
1079
1094
return CallableType (
1080
1095
arg_types = arg_types if arg_types is not _dummy else self .arg_types ,
1081
1096
arg_kinds = arg_kinds if arg_kinds is not _dummy else self .arg_kinds ,
@@ -1094,6 +1109,7 @@ def copy_modified(self,
1094
1109
from_type_type = from_type_type if from_type_type is not _dummy else self .from_type_type ,
1095
1110
bound_args = bound_args if bound_args is not _dummy else self .bound_args ,
1096
1111
def_extras = def_extras if def_extras is not _dummy else dict (self .def_extras ),
1112
+ type_guard = type_guard if type_guard is not _dummy else self .type_guard ,
1097
1113
)
1098
1114
1099
1115
def var_arg (self ) -> Optional [FormalArgument ]:
@@ -1255,6 +1271,8 @@ def __eq__(self, other: object) -> bool:
1255
1271
def serialize (self ) -> JsonDict :
1256
1272
# TODO: As an optimization, leave out everything related to
1257
1273
# generic functions for non-generic functions.
1274
+ assert (self .type_guard is None
1275
+ or isinstance (get_proper_type (self .type_guard ), Instance )), str (self .type_guard )
1258
1276
return {'.class' : 'CallableType' ,
1259
1277
'arg_types' : [t .serialize () for t in self .arg_types ],
1260
1278
'arg_kinds' : self .arg_kinds ,
@@ -1269,6 +1287,7 @@ def serialize(self) -> JsonDict:
1269
1287
'bound_args' : [(None if t is None else t .serialize ())
1270
1288
for t in self .bound_args ],
1271
1289
'def_extras' : dict (self .def_extras ),
1290
+ 'type_guard' : self .type_guard .serialize () if self .type_guard is not None else None ,
1272
1291
}
1273
1292
1274
1293
@classmethod
@@ -1286,7 +1305,9 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
1286
1305
implicit = data ['implicit' ],
1287
1306
bound_args = [(None if t is None else deserialize_type (t ))
1288
1307
for t in data ['bound_args' ]],
1289
- def_extras = data ['def_extras' ]
1308
+ def_extras = data ['def_extras' ],
1309
+ type_guard = (deserialize_type (data ['type_guard' ])
1310
+ if data ['type_guard' ] is not None else None ),
1290
1311
)
1291
1312
1292
1313
@@ -2097,7 +2118,10 @@ def visit_callable_type(self, t: CallableType) -> str:
2097
2118
s = '({})' .format (s )
2098
2119
2099
2120
if not isinstance (get_proper_type (t .ret_type ), NoneType ):
2100
- s += ' -> {}' .format (t .ret_type .accept (self ))
2121
+ if t .type_guard is not None :
2122
+ s += ' -> TypeGuard[{}]' .format (t .type_guard .accept (self ))
2123
+ else :
2124
+ s += ' -> {}' .format (t .ret_type .accept (self ))
2101
2125
2102
2126
if t .variables :
2103
2127
vs = []
0 commit comments