Skip to content

Commit f82a019

Browse files
committed
Add support for narrowing Literals using equality
This pull request (finally) adds support for narrowing expressions using Literal types by equality, instead of just identity. For example, the following "tagged union" pattern is now supported: ```python class Foo(TypedDict): key: Literal["A"] blah: int class Bar(TypedDict): key: Literal["B"] something: str x: Union[Foo, Bar] if x.key == "A": reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` Previously, this was possible to do only with Enum Literals and the `is` operator, which is perhaps not very intuitive. The main limitation with this pull request is that it'll perform narrowing only if either the LHS or RHS contains an explicit Literal type somewhere. If this limitation is not present, we end up breaking a decent amount of real-world code -- mostly tests -- that do something like this: ```python def some_test_case() -> None: worker = Worker() # Without the limitation, we narrow 'worker.state' to # Literal['ready'] in this assert... assert worker.state == 'ready' worker.start() # ...which subsequently causes this second assert to narrow # worker.state to <uninhabited>, causing the last line to be # unreachable. assert worker.state == 'running' worker.query() ``` I tried for several weeks to find a more intelligent way around this problem, but everything I tried ended up being either insufficient or super-hacky, so I gave up and went for this brute-force solution. The other main limitation is that we perform narrowing only if both the LHS and RHS do not define custom `__eq__` or `__ne__` methods, but this seems like a more reasonable one to me. Resolves #7944.
1 parent 9101707 commit f82a019

File tree

6 files changed

+491
-126
lines changed

6 files changed

+491
-126
lines changed

mypy/checker.py

Lines changed: 110 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import (
88
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
9-
Sequence, Mapping, Generic, AbstractSet
9+
Sequence, Mapping, Generic, AbstractSet, Callable
1010
)
1111
from typing_extensions import Final
1212

@@ -50,7 +50,8 @@
5050
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
5151
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
5252
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
53-
true_only, false_only, function_type, TypeVarExtractor,
53+
true_only, false_only, function_type, TypeVarExtractor, custom_special_method,
54+
is_literal_type_like,
5455
)
5556
from mypy import message_registry
5657
from mypy.subtypes import (
@@ -3890,20 +3891,59 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
38903891

38913892
partial_type_maps = []
38923893
for operator, expr_indices in simplified_operator_list:
3893-
if operator in {'is', 'is not'}:
3894-
if_map, else_map = self.refine_identity_comparison_expression(
3895-
operands,
3896-
operand_types,
3897-
expr_indices,
3898-
narrowable_operand_index_to_hash.keys(),
3899-
)
3900-
elif operator in {'==', '!='}:
3901-
if_map, else_map = self.refine_equality_comparison_expression(
3902-
operands,
3903-
operand_types,
3904-
expr_indices,
3905-
narrowable_operand_index_to_hash.keys(),
3906-
)
3894+
if operator in {'is', 'is not', '==', '!='}:
3895+
# is_valid_target:
3896+
# Controls which types we're allowed to narrow exprs to. Note that
3897+
# we cannot use 'is_literal_type_like' in both cases since doing
3898+
# 'x = 10000 + 1; x is 10001' is not always True in all Python impls.
3899+
#
3900+
# coerce_only_in_literal_context:
3901+
# If true, coerce types into literal types only if one or more of
3902+
# the provided exprs contains an explicit Literal type. This could
3903+
# technically be set to any arbitrary value, but it seems being liberal
3904+
# with narrowing when using 'is' and conservative when using '==' seems
3905+
# to break the least amount of real-world code.
3906+
#
3907+
# should_narrow_by_identity:
3908+
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
3909+
# that could cause identity-based narrowing to produce invalid results.
3910+
if operator in {'is', 'is not'}:
3911+
is_valid_target = is_singleton_type # type: Callable[[Type], bool]
3912+
coerce_only_in_literal_context = False
3913+
should_narrow_by_identity = True
3914+
else:
3915+
is_valid_target = is_exactly_literal_type
3916+
coerce_only_in_literal_context = True
3917+
3918+
def has_no_custom_eq_checks(t: Type) -> bool:
3919+
return not custom_special_method(t, '__eq__', check_all=False) \
3920+
and not custom_special_method(t, '__ne__', check_all=False)
3921+
expr_types = [operand_types[i] for i in expr_indices]
3922+
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))
3923+
3924+
if_map = {} # type: TypeMap
3925+
else_map = {} # type: TypeMap
3926+
if should_narrow_by_identity:
3927+
if_map, else_map = self.refine_identity_comparison_expression(
3928+
operands,
3929+
operand_types,
3930+
expr_indices,
3931+
narrowable_operand_index_to_hash.keys(),
3932+
is_valid_target,
3933+
coerce_only_in_literal_context,
3934+
)
3935+
3936+
# Strictly speaking, we should also skip this check if the objects in the expr
3937+
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
3938+
# assume nobody would actually create a custom objects that considers itself
3939+
# equal to None.
3940+
if if_map == {} and else_map == {}:
3941+
if_map, else_map = self.refine_away_none_in_comparison(
3942+
operands,
3943+
operand_types,
3944+
expr_indices,
3945+
narrowable_operand_index_to_hash.keys(),
3946+
)
39073947
elif operator in {'in', 'not in'}:
39083948
assert len(expr_indices) == 2
39093949
left_index, right_index = expr_indices
@@ -4146,8 +4186,10 @@ def refine_identity_comparison_expression(self,
41464186
operand_types: List[Type],
41474187
chain_indices: List[int],
41484188
narrowable_operand_indices: AbstractSet[int],
4189+
is_valid_target: Callable[[ProperType], bool],
4190+
coerce_only_in_literal_context: bool,
41494191
) -> Tuple[TypeMap, TypeMap]:
4150-
"""Produces conditional type maps refining expressions used in an identity comparison.
4192+
"""Produces conditional type maps refining exprs used in an identity/equality comparison.
41514193
41524194
The 'operands' and 'operand_types' lists should be the full list of operands used
41534195
in the overall comparison expression. The 'chain_indices' list is the list of indices
@@ -4163,30 +4205,45 @@ def refine_identity_comparison_expression(self,
41634205
The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
41644206
to refine the types of: that is, all operands that will potentially be a part of
41654207
the output TypeMaps.
4208+
4209+
Although this function could theoretically try setting the types of the operands
4210+
in the chains to the meet, doing that causes too many issues in real-world code.
4211+
Instead, we use 'is_valid_target' to identify which of the given chain types
4212+
we could plausibly use as the refined type for the expressions in the chain.
4213+
4214+
Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
4215+
expressions in the chain to a Literal type. Performing this coercion is sometimes
4216+
too aggressive of a narrowing, depending on context.
41664217
"""
4167-
singleton = None # type: Optional[ProperType]
4168-
possible_singleton_indices = []
4218+
should_coerce = True
4219+
if coerce_only_in_literal_context:
4220+
should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices)
4221+
4222+
target = None # type: Optional[Type]
4223+
possible_target_indices = []
41694224
for i in chain_indices:
4170-
coerced_type = coerce_to_literal(operand_types[i])
4171-
if not is_singleton_type(coerced_type):
4225+
expr_type = operand_types[i]
4226+
if should_coerce:
4227+
expr_type = coerce_to_literal(expr_type)
4228+
if not is_valid_target(get_proper_type(expr_type)):
41724229
continue
4173-
if singleton and not is_same_type(singleton, coerced_type):
4174-
# We have multiple disjoint singleton types. So the 'if' branch
4230+
if target and not is_same_type(target, expr_type):
4231+
# We have multiple disjoint target types. So the 'if' branch
41754232
# must be unreachable.
41764233
return None, {}
4177-
singleton = coerced_type
4178-
possible_singleton_indices.append(i)
4234+
target = expr_type
4235+
possible_target_indices.append(i)
41794236

4180-
# There's nothing we can currently infer if none of the operands are singleton types,
4237+
# There's nothing we can currently infer if none of the operands are valid targets,
41814238
# so we end early and infer nothing.
4182-
if singleton is None:
4239+
if target is None:
41834240
return {}, {}
41844241

4185-
# If possible, use an unassignable expression as the singleton.
4186-
# We skip refining the type of the singleton below, so ideally we'd
4242+
# If possible, use an unassignable expression as the target.
4243+
# We skip refining the type of the target below, so ideally we'd
41874244
# want to pick an expression we were going to skip anyways.
41884245
singleton_index = -1
4189-
for i in possible_singleton_indices:
4246+
for i in possible_target_indices:
41904247
if i not in narrowable_operand_indices:
41914248
singleton_index = i
41924249

@@ -4215,20 +4272,21 @@ def refine_identity_comparison_expression(self,
42154272
# currently will just mark the whole branch as unreachable if either operand is
42164273
# narrowed to <uninhabited>.
42174274
if singleton_index == -1:
4218-
singleton_index = possible_singleton_indices[-1]
4275+
singleton_index = possible_target_indices[-1]
42194276

42204277
enum_name = None
4221-
if isinstance(singleton, LiteralType) and singleton.is_enum_literal():
4222-
enum_name = singleton.fallback.type.fullname
4278+
target = get_proper_type(target)
4279+
if isinstance(target, LiteralType) and target.is_enum_literal():
4280+
enum_name = target.fallback.type.fullname
42234281

4224-
target_type = [TypeRange(singleton, is_upper_bound=False)]
4282+
target_type = [TypeRange(target, is_upper_bound=False)]
42254283

42264284
partial_type_maps = []
42274285
for i in chain_indices:
4228-
# If we try refining a singleton against itself, conditional_type_map
4286+
# If we try refining a type against itself, conditional_type_map
42294287
# will end up assuming that the 'else' branch is unreachable. This is
42304288
# typically not what we want: generally the user will intend for the
4231-
# singleton type to be some fixed 'sentinel' value and will want to refine
4289+
# target type to be some fixed 'sentinel' value and will want to refine
42324290
# the other exprs against this one instead.
42334291
if i == singleton_index:
42344292
continue
@@ -4246,17 +4304,16 @@ def refine_identity_comparison_expression(self,
42464304

42474305
return reduce_partial_conditional_maps(partial_type_maps)
42484306

4249-
def refine_equality_comparison_expression(self,
4250-
operands: List[Expression],
4251-
operand_types: List[Type],
4252-
chain_indices: List[int],
4253-
narrowable_operand_indices: AbstractSet[int],
4254-
) -> Tuple[TypeMap, TypeMap]:
4255-
"""Produces conditional type maps refining expressions used in an equality comparison.
4307+
def refine_away_none_in_comparison(self,
4308+
operands: List[Expression],
4309+
operand_types: List[Type],
4310+
chain_indices: List[int],
4311+
narrowable_operand_indices: AbstractSet[int],
4312+
) -> Tuple[TypeMap, TypeMap]:
4313+
"""Produces conditional type maps refining away None in an identity/equality chain.
42564314
4257-
For more details, see the docstring of 'refine_equality_comparison' up above.
4258-
The only difference is that this function is for refining equality operations
4259-
(e.g. 'a == b == c') instead of identity ('a is b is c').
4315+
For more details about what the different arguments mean, see the
4316+
docstring of 'refine_identity_comparison_expression' up above.
42604317
"""
42614318
non_optional_types = []
42624319
for i in chain_indices:
@@ -4749,7 +4806,7 @@ class Foo(Enum):
47494806
return False
47504807

47514808
parent_type = get_proper_type(parent_type)
4752-
member_type = coerce_to_literal(member_type)
4809+
member_type = get_proper_type(coerce_to_literal(member_type))
47534810
if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType):
47544811
return False
47554812

@@ -5540,3 +5597,9 @@ def has_bool_item(typ: ProperType) -> bool:
55405597
return any(is_named_instance(item, 'builtins.bool')
55415598
for item in typ.items)
55425599
return False
5600+
5601+
5602+
# TODO: why can't we define this as an inline function?
5603+
# Does mypyc not support them?
5604+
def is_exactly_literal_type(t: Type) -> bool:
5605+
return isinstance(get_proper_type(t), LiteralType)

mypy/checkexpr.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
3333
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
3434
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
35-
SYMBOL_FUNCBASE_TYPES
3635
)
3736
from mypy.literals import literal
3837
from mypy import nodes
@@ -51,15 +50,16 @@
5150
from mypy import erasetype
5251
from mypy.checkmember import analyze_member_access, type_object_type
5352
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
54-
from mypy.checkstrformat import StringFormatterChecker, custom_special_method
53+
from mypy.checkstrformat import StringFormatterChecker
5554
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
5655
from mypy.util import split_module_names
5756
from mypy.typevars import fill_typevars
5857
from mypy.visitor import ExpressionVisitor
5958
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
6059
from mypy.typeops import (
6160
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
62-
function_type, callable_type, try_getting_str_literals
61+
function_type, callable_type, try_getting_str_literals, custom_special_method,
62+
is_literal_type_like,
6363
)
6464
import mypy.errorcodes as codes
6565

@@ -4266,24 +4266,6 @@ def merge_typevars_in_callables_by_name(
42664266
return output, variables
42674267

42684268

4269-
def is_literal_type_like(t: Optional[Type]) -> bool:
4270-
"""Returns 'true' if the given type context is potentially either a LiteralType,
4271-
a Union of LiteralType, or something similar.
4272-
"""
4273-
t = get_proper_type(t)
4274-
if t is None:
4275-
return False
4276-
elif isinstance(t, LiteralType):
4277-
return True
4278-
elif isinstance(t, UnionType):
4279-
return any(is_literal_type_like(item) for item in t.items)
4280-
elif isinstance(t, TypeVarType):
4281-
return (is_literal_type_like(t.upper_bound)
4282-
or any(is_literal_type_like(item) for item in t.values))
4283-
else:
4284-
return False
4285-
4286-
42874269
def try_getting_literal(typ: Type) -> ProperType:
42884270
"""If possible, get a more precise literal type for a given type."""
42894271
typ = get_proper_type(typ)
@@ -4305,29 +4287,6 @@ def is_expr_literal_type(node: Expression) -> bool:
43054287
return False
43064288

43074289

4308-
def custom_equality_method(typ: Type) -> bool:
4309-
"""Does this type have a custom __eq__() method?"""
4310-
typ = get_proper_type(typ)
4311-
if isinstance(typ, Instance):
4312-
method = typ.type.get('__eq__')
4313-
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
4314-
if method.node.info:
4315-
return not method.node.info.fullname.startswith('builtins.')
4316-
return False
4317-
if isinstance(typ, UnionType):
4318-
return any(custom_equality_method(t) for t in typ.items)
4319-
if isinstance(typ, TupleType):
4320-
return custom_equality_method(tuple_fallback(typ))
4321-
if isinstance(typ, CallableType) and typ.is_type_obj():
4322-
# Look up __eq__ on the metaclass for class objects.
4323-
return custom_equality_method(typ.fallback)
4324-
if isinstance(typ, AnyType):
4325-
# Avoid false positives in uncertain cases.
4326-
return True
4327-
# TODO: support other types (see ExpressionChecker.has_member())?
4328-
return False
4329-
4330-
43314290
def has_bytes_component(typ: Type, py2: bool = False) -> bool:
43324291
"""Is this one of builtin byte types, or a union that contains it?"""
43334292
typ = get_proper_type(typ)

mypy/checkstrformat.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
from mypy.types import (
2121
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
22-
CallableType, LiteralType, get_proper_types
22+
LiteralType, get_proper_types
2323
)
2424
from mypy.nodes import (
2525
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
2626
IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2,
27-
SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
27+
Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
2828
)
2929
import mypy.errorcodes as codes
3030

@@ -35,7 +35,7 @@
3535
from mypy import message_registry
3636
from mypy.messages import MessageBuilder
3737
from mypy.maptype import map_instance_to_supertype
38-
from mypy.typeops import tuple_fallback
38+
from mypy.typeops import custom_special_method
3939
from mypy.subtypes import is_subtype
4040
from mypy.parse import parse
4141

@@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool:
961961
elif isinstance(typ, UnionType):
962962
return any(has_type_component(t, fullname) for t in typ.relevant_items())
963963
return False
964-
965-
966-
def custom_special_method(typ: Type, name: str,
967-
check_all: bool = False) -> bool:
968-
"""Does this type have a custom special method such as __format__() or __eq__()?
969-
970-
If check_all is True ensure all items of a union have a custom method, not just some.
971-
"""
972-
typ = get_proper_type(typ)
973-
if isinstance(typ, Instance):
974-
method = typ.type.get(name)
975-
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
976-
if method.node.info:
977-
return not method.node.info.fullname.startswith('builtins.')
978-
return False
979-
if isinstance(typ, UnionType):
980-
if check_all:
981-
return all(custom_special_method(t, name, check_all) for t in typ.items)
982-
return any(custom_special_method(t, name) for t in typ.items)
983-
if isinstance(typ, TupleType):
984-
return custom_special_method(tuple_fallback(typ), name)
985-
if isinstance(typ, CallableType) and typ.is_type_obj():
986-
# Look up __method__ on the metaclass for class objects.
987-
return custom_special_method(typ.fallback, name)
988-
if isinstance(typ, AnyType):
989-
# Avoid false positives in uncertain cases.
990-
return True
991-
# TODO: support other types (see ExpressionChecker.has_member())?
992-
return False

0 commit comments

Comments
 (0)