Skip to content

Commit bcb3747

Browse files
Implement TypeIs (PEP 742) (#16898)
Co-authored-by: Marc Mueller <[email protected]>
1 parent 3c87af2 commit bcb3747

19 files changed

+962
-19
lines changed

docs/source/error_code_list2.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,19 @@ Correct usage:
555555
556556
When this code is enabled, using ``reveal_locals`` is always an error,
557557
because there's no way one can import it.
558+
559+
.. _code-narrowed-type-not-subtype:
560+
561+
Check that ``TypeIs`` narrows types [narrowed-type-not-subtype]
562+
---------------------------------------------------------------
563+
564+
:pep:`742` requires that when ``TypeIs`` is used, the narrowed
565+
type must be a subtype of the original type::
566+
567+
from typing_extensions import TypeIs
568+
569+
def f(x: int) -> TypeIs[str]: # Error, str is not a subtype of int
570+
...
571+
572+
def g(x: object) -> TypeIs[str]: # OK
573+
...

mypy/applytype.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,15 @@ def apply_generic_arguments(
137137
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
138138
)
139139

140-
# Apply arguments to TypeGuard if any.
140+
# Apply arguments to TypeGuard and TypeIs if any.
141141
if callable.type_guard is not None:
142142
type_guard = expand_type(callable.type_guard, id_to_type)
143143
else:
144144
type_guard = None
145+
if callable.type_is is not None:
146+
type_is = expand_type(callable.type_is, id_to_type)
147+
else:
148+
type_is = None
145149

146150
# The callable may retain some type vars if only some were applied.
147151
# TODO: move apply_poly() logic from checkexpr.py here when new inference
@@ -164,4 +168,5 @@ def apply_generic_arguments(
164168
ret_type=expand_type(callable.ret_type, id_to_type),
165169
variables=remaining_tvars,
166170
type_guard=type_guard,
171+
type_is=type_is,
167172
)

mypy/checker.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,22 @@ def check_func_def(
12031203
# visible from *inside* of this function/method.
12041204
ref_type: Type | None = self.scope.active_self_type()
12051205

1206+
if typ.type_is:
1207+
arg_index = 0
1208+
# For methods and classmethods, we want the second parameter
1209+
if ref_type is not None and (not defn.is_static or defn.name == "__new__"):
1210+
arg_index = 1
1211+
if arg_index < len(typ.arg_types) and not is_subtype(
1212+
typ.type_is, typ.arg_types[arg_index]
1213+
):
1214+
self.fail(
1215+
message_registry.NARROWED_TYPE_NOT_SUBTYPE.format(
1216+
format_type(typ.type_is, self.options),
1217+
format_type(typ.arg_types[arg_index], self.options),
1218+
),
1219+
item,
1220+
)
1221+
12061222
# Store argument types.
12071223
for i in range(len(typ.arg_types)):
12081224
arg_type = typ.arg_types[i]
@@ -2178,6 +2194,8 @@ def check_override(
21782194
elif isinstance(original, CallableType) and isinstance(override, CallableType):
21792195
if original.type_guard is not None and override.type_guard is None:
21802196
fail = True
2197+
if original.type_is is not None and override.type_is is None:
2198+
fail = True
21812199

21822200
if is_private(name):
21832201
fail = False
@@ -5643,7 +5661,7 @@ def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
56435661
def find_isinstance_check(self, node: Expression) -> tuple[TypeMap, TypeMap]:
56445662
"""Find any isinstance checks (within a chain of ands). Includes
56455663
implicit and explicit checks for None and calls to callable.
5646-
Also includes TypeGuard functions.
5664+
Also includes TypeGuard and TypeIs functions.
56475665
56485666
Return value is a map of variables to their types if the condition
56495667
is true and a map of variables to their types if the condition is false.
@@ -5695,7 +5713,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
56955713
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
56965714
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
56975715
elif isinstance(node.callee, RefExpr):
5698-
if node.callee.type_guard is not None:
5716+
if node.callee.type_guard is not None or node.callee.type_is is not None:
56995717
# TODO: Follow *args, **kwargs
57005718
if node.arg_kinds[0] != nodes.ARG_POS:
57015719
# the first argument might be used as a kwarg
@@ -5721,15 +5739,31 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
57215739
# we want the idx-th variable to be narrowed
57225740
expr = collapse_walrus(node.args[idx])
57235741
else:
5724-
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
5742+
kind = (
5743+
"guard" if node.callee.type_guard is not None else "narrower"
5744+
)
5745+
self.fail(
5746+
message_registry.TYPE_GUARD_POS_ARG_REQUIRED.format(kind), node
5747+
)
57255748
return {}, {}
57265749
if literal(expr) == LITERAL_TYPE:
57275750
# Note: we wrap the target type, so that we can special case later.
57285751
# Namely, for isinstance() we use a normal meet, while TypeGuard is
57295752
# considered "always right" (i.e. even if the types are not overlapping).
57305753
# Also note that a care must be taken to unwrap this back at read places
57315754
# where we use this to narrow down declared type.
5732-
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
5755+
if node.callee.type_guard is not None:
5756+
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
5757+
else:
5758+
assert node.callee.type_is is not None
5759+
return conditional_types_to_typemaps(
5760+
expr,
5761+
*self.conditional_types_with_intersection(
5762+
self.lookup_type(expr),
5763+
[TypeRange(node.callee.type_is, is_upper_bound=False)],
5764+
expr,
5765+
),
5766+
)
57335767
elif isinstance(node, ComparisonExpr):
57345768
# Step 1: Obtain the types of each operand and whether or not we can
57355769
# narrow their types. (For example, we shouldn't try narrowing the

mypy/checkexpr.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,13 +1451,12 @@ def check_call_expr_with_callee_type(
14511451
object_type=object_type,
14521452
)
14531453
proper_callee = get_proper_type(callee_type)
1454-
if (
1455-
isinstance(e.callee, RefExpr)
1456-
and isinstance(proper_callee, CallableType)
1457-
and proper_callee.type_guard is not None
1458-
):
1454+
if isinstance(e.callee, RefExpr) and isinstance(proper_callee, CallableType):
14591455
# Cache it for find_isinstance_check()
1460-
e.callee.type_guard = proper_callee.type_guard
1456+
if proper_callee.type_guard is not None:
1457+
e.callee.type_guard = proper_callee.type_guard
1458+
if proper_callee.type_is is not None:
1459+
e.callee.type_is = proper_callee.type_is
14611460
return ret_type
14621461

14631462
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
@@ -5283,7 +5282,7 @@ def infer_lambda_type_using_context(
52835282
# is a constructor -- but this fallback doesn't make sense for lambdas.
52845283
callable_ctx = callable_ctx.copy_modified(fallback=self.named_type("builtins.function"))
52855284

5286-
if callable_ctx.type_guard is not None:
5285+
if callable_ctx.type_guard is not None or callable_ctx.type_is is not None:
52875286
# Lambda's return type cannot be treated as a `TypeGuard`,
52885287
# because it is implicit. And `TypeGuard`s must be explicit.
52895288
# See https://github.com/python/mypy/issues/9927

mypy/constraints.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,10 +1018,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
10181018
param_spec = template.param_spec()
10191019

10201020
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
1021-
if template.type_guard is not None:
1021+
if template.type_guard is not None and cactual.type_guard is not None:
10221022
template_ret_type = template.type_guard
1023-
if cactual.type_guard is not None:
10241023
cactual_ret_type = cactual.type_guard
1024+
elif template.type_guard is not None:
1025+
template_ret_type = AnyType(TypeOfAny.special_form)
1026+
elif cactual.type_guard is not None:
1027+
cactual_ret_type = AnyType(TypeOfAny.special_form)
1028+
1029+
if template.type_is is not None and cactual.type_is is not None:
1030+
template_ret_type = template.type_is
1031+
cactual_ret_type = cactual.type_is
1032+
elif template.type_is is not None:
1033+
template_ret_type = AnyType(TypeOfAny.special_form)
1034+
elif cactual.type_is is not None:
1035+
cactual_ret_type = AnyType(TypeOfAny.special_form)
1036+
10251037
res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
10261038

10271039
if param_spec is None:

mypy/errorcodes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,5 +281,11 @@ def __hash__(self) -> int:
281281
sub_code_of=MISC,
282282
)
283283

284+
NARROWED_TYPE_NOT_SUBTYPE: Final[ErrorCode] = ErrorCode(
285+
"narrowed-type-not-subtype",
286+
"Warn if a TypeIs function's narrowed type is not a subtype of the original type",
287+
"General",
288+
)
289+
284290
# This copy will not include any error codes defined later in the plugins.
285291
mypy_error_codes = error_codes.copy()

mypy/expandtype.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
351351
arg_names=t.arg_names[:-2] + repl.arg_names,
352352
ret_type=t.ret_type.accept(self),
353353
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
354+
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
354355
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
355356
variables=[*repl.variables, *t.variables],
356357
)
@@ -384,6 +385,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
384385
arg_types=arg_types,
385386
ret_type=t.ret_type.accept(self),
386387
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
388+
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
387389
)
388390
if needs_normalization:
389391
return expanded.with_normalized_var_args()

mypy/fixup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
270270
arg.accept(self)
271271
if ct.type_guard is not None:
272272
ct.type_guard.accept(self)
273+
if ct.type_is is not None:
274+
ct.type_is.accept(self)
273275

274276
def visit_overloaded(self, t: Overloaded) -> None:
275277
for ct in t.items:

mypy/message_registry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
262262

263263
CONTIGUOUS_ITERABLE_EXPECTED: Final = ErrorMessage("Contiguous iterable with same type expected")
264264
ITERABLE_TYPE_EXPECTED: Final = ErrorMessage("Invalid type '{}' for *expr (iterable expected)")
265-
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type guard requires positional argument")
265+
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type {} requires positional argument")
266266

267267
# Match Statement
268268
MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"'
@@ -324,3 +324,6 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
324324
ARG_NAME_EXPECTED_STRING_LITERAL: Final = ErrorMessage(
325325
"Expected string literal for argument name, got {}", codes.SYNTAX
326326
)
327+
NARROWED_TYPE_NOT_SUBTYPE: Final = ErrorMessage(
328+
"Narrowed type {} is not a subtype of input type {}", codes.NARROWED_TYPE_NOT_SUBTYPE
329+
)

mypy/messages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2643,6 +2643,8 @@ def format_literal_value(typ: LiteralType) -> str:
26432643
elif isinstance(func, CallableType):
26442644
if func.type_guard is not None:
26452645
return_type = f"TypeGuard[{format(func.type_guard)}]"
2646+
elif func.type_is is not None:
2647+
return_type = f"TypeIs[{format(func.type_is)}]"
26462648
else:
26472649
return_type = format(func.ret_type)
26482650
if func.is_ellipsis_args:
@@ -2859,6 +2861,8 @@ def [T <: int] f(self, x: int, y: T) -> None
28592861
s += " -> "
28602862
if tp.type_guard is not None:
28612863
s += f"TypeGuard[{format_type_bare(tp.type_guard, options)}]"
2864+
elif tp.type_is is not None:
2865+
s += f"TypeIs[{format_type_bare(tp.type_is, options)}]"
28622866
else:
28632867
s += format_type_bare(tp.ret_type, options)
28642868

0 commit comments

Comments
 (0)