Skip to content

Commit 9b308b6

Browse files
committed
Refine parent type when narrowing "lookup" expressions
This diff adds support for the following pattern: ```python from typing import Enum, List from enum import Enum class Key(Enum): A = 1 B = 2 class Foo: key: Literal[Key.A] blah: List[int] class Bar: key: Literal[Key.B] something: List[str] x: Union[Foo, Bar] if x.key is Key.A: reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` In short, when we do `x.key is Key.A`, we "propagate" the information we discovered about `x.key` up one level to refine the type of `x`. We perform this propagation only when `x` is a Union and only when we are doing member or index lookups into instances, typeddicts, namedtuples, and tuples. For indexing operations, we have one additional limitation: we *must* use a literal expression in order for narrowing to work at all. Using Literal types or Final instances won't work; See python#7905 for more details. To put it another way, this adds support for tagged unions, I guess. This more or less resolves python#7344. We currently don't have support for narrowing based on string or int literals, but that's a separate issue and should be resolved by python#7169 (which I resumed work on earlier this week).
1 parent 84126ab commit 9b308b6

File tree

6 files changed

+471
-20
lines changed

6 files changed

+471
-20
lines changed

mypy/checker.py

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import (
88
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable,
9-
Sequence
9+
Mapping, Sequence
1010
)
1111
from typing_extensions import Final
1212

@@ -47,7 +47,9 @@
4747
)
4848
from mypy.typeops import (
4949
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
50-
erase_def_to_union_or_bound, erase_to_union_or_bound,
50+
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
51+
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
52+
tuple_fallback, lookup_attribute_type, is_singleton_type, try_expanding_enum_to_union,
5153
true_only, false_only, function_type,
5254
)
5355
from mypy import message_registry
@@ -72,9 +74,6 @@
7274
from mypy.plugin import Plugin, CheckerPluginInterface
7375
from mypy.sharedparse import BINARY_MAGIC_METHODS
7476
from mypy.scope import Scope
75-
from mypy.typeops import (
76-
tuple_fallback, coerce_to_literal, is_singleton_type, try_expanding_enum_to_union
77-
)
7877
from mypy import state, errorcodes as codes
7978
from mypy.traverser import has_return_statement, all_return_statements
8079
from mypy.errorcodes import ErrorCode
@@ -3709,6 +3708,12 @@ def find_isinstance_check(self, node: Expression
37093708
37103709
Guaranteed to not return None, None. (But may return {}, {})
37113710
"""
3711+
if_map, else_map = self.find_isinstance_check_helper(node)
3712+
new_if_map = propagate_up_typemap_info(self.type_map, if_map)
3713+
new_else_map = propagate_up_typemap_info(self.type_map, else_map)
3714+
return new_if_map, new_else_map
3715+
3716+
def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeMap]:
37123717
type_map = self.type_map
37133718
if is_true_literal(node):
37143719
return {}, None
@@ -3835,23 +3840,23 @@ def find_isinstance_check(self, node: Expression
38353840
else None)
38363841
return if_map, else_map
38373842
elif isinstance(node, OpExpr) and node.op == 'and':
3838-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3839-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3843+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3844+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38403845

38413846
# (e1 and e2) is true if both e1 and e2 are true,
38423847
# and false if at least one of e1 and e2 is false.
38433848
return (and_conditional_maps(left_if_vars, right_if_vars),
38443849
or_conditional_maps(left_else_vars, right_else_vars))
38453850
elif isinstance(node, OpExpr) and node.op == 'or':
3846-
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
3847-
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
3851+
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
3852+
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
38483853

38493854
# (e1 or e2) is true if at least one of e1 or e2 is true,
38503855
# and false if both e1 and e2 are false.
38513856
return (or_conditional_maps(left_if_vars, right_if_vars),
38523857
and_conditional_maps(left_else_vars, right_else_vars))
38533858
elif isinstance(node, UnaryExpr) and node.op == 'not':
3854-
left, right = self.find_isinstance_check(node.expr)
3859+
left, right = self.find_isinstance_check_helper(node.expr)
38553860
return right, left
38563861

38573862
# Not a supported isinstance check
@@ -4780,3 +4785,96 @@ def has_bool_item(typ: ProperType) -> bool:
47804785
return any(is_named_instance(item, 'builtins.bool')
47814786
for item in typ.items)
47824787
return False
4788+
4789+
4790+
def propagate_up_typemap_info(existing_types: Mapping[Expression, Type],
4791+
new_types: TypeMap) -> TypeMap:
4792+
"""Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types.
4793+
4794+
Specifically, this function accepts two mappings of expression to original types:
4795+
the original mapping (existing_types), and a new mapping (new_types) intended to
4796+
update the original.
4797+
4798+
This function iterates through new_types and attempts to use the information to try
4799+
refining the parent type if
4800+
"""
4801+
if new_types is None:
4802+
return None
4803+
output_map = {}
4804+
for expr, typ in new_types.items():
4805+
# The original inferred type should always be present in the output map, of course
4806+
output_map[expr] = typ
4807+
4808+
# Next, check and see if this expression is one that's attempting to
4809+
# "index" into the parent type. If so, grab both the parent and the "key".
4810+
keys = [] # type: Sequence[Union[str, int]]
4811+
if isinstance(expr, MemberExpr):
4812+
parent_expr = expr.expr
4813+
parent_type = existing_types.get(parent_expr)
4814+
variant_name = expr.name
4815+
keys = [variant_name]
4816+
elif isinstance(expr, IndexExpr):
4817+
parent_expr = expr.base
4818+
parent_type = existing_types.get(parent_expr)
4819+
4820+
variant_type = existing_types.get(expr.index)
4821+
if variant_type is None:
4822+
continue
4823+
4824+
str_literals = try_getting_str_literals_from_type(variant_type)
4825+
if str_literals is not None:
4826+
keys = str_literals
4827+
else:
4828+
int_literals = try_getting_int_literals_from_type(variant_type)
4829+
if int_literals is not None:
4830+
keys = int_literals
4831+
else:
4832+
continue
4833+
else:
4834+
continue
4835+
4836+
# We don't try inferring anything if we've either already inferred something for
4837+
# the parent expression or if the parent somehow doesn't already have an existing type
4838+
if parent_expr in new_types or parent_type is None:
4839+
continue
4840+
4841+
# If the parent isn't a union, we won't be able to perform any useful refinements.
4842+
# So, give up and carry on.
4843+
#
4844+
# TODO: We currently refine just the immediate parent. Should we also try refining
4845+
# any parents of the parents?
4846+
#
4847+
# One quick-and-dirty way of doing this would be to have the caller repeatedly run
4848+
# this function until we seem fixpoint, but that seems expensive.
4849+
parent_type = get_proper_type(parent_type)
4850+
if not isinstance(parent_type, UnionType):
4851+
continue
4852+
4853+
# Take each potential parent type in the union and try "indexing" into it using.
4854+
# Does the resulting type overlap with the deduced type of the original expression?
4855+
# If so, keep the parent type in the union.
4856+
new_parent_types = []
4857+
for item in parent_type.items:
4858+
item = get_proper_type(item)
4859+
member_types = []
4860+
for key in keys:
4861+
t = lookup_attribute_type(item, key)
4862+
if t is not None:
4863+
member_types.append(t)
4864+
member_type_for_item = make_simplified_union(member_types)
4865+
if member_type_for_item is None:
4866+
# We were unable to obtain the member type. So, we give up on refining this
4867+
# parent type entirely.
4868+
new_parent_types = []
4869+
break
4870+
4871+
if is_overlapping_types(member_type_for_item, typ):
4872+
new_parent_types.append(item)
4873+
4874+
# If none of the parent types overlap (if we derived an empty union), either
4875+
# we deliberately aborted or something went wrong. Deriving the uninhabited
4876+
# type seems unhelpful, so let's just skip refining the parent expression.
4877+
if new_parent_types:
4878+
output_map[parent_expr] = make_simplified_union(new_parent_types)
4879+
4880+
return output_map

mypy/checkexpr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,6 +2704,9 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr,
27042704
index = e.index
27052705
left_type = get_proper_type(left_type)
27062706

2707+
# Visit the index, just to make sure we have a type for it available
2708+
self.accept(index)
2709+
27072710
if isinstance(left_type, UnionType):
27082711
original_type = original_type or left_type
27092712
return make_simplified_union([self.visit_index_with_type(typ, e,

mypy/test/testcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
'check-isinstance.test',
4747
'check-lists.test',
4848
'check-namedtuple.test',
49+
'check-narrowing.test',
4950
'check-typeddict.test',
5051
'check-type-aliases.test',
5152
'check-ignore.test',

mypy/typeops.py

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set
8+
from typing import cast, Optional, List, Sequence, Set, Union, TypeVar, Type as TypingType
99
import sys
1010

1111
from mypy.types import (
1212
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded,
13-
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
13+
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType,
1414
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
1515
copy_type, TypeAliasType
1616
)
@@ -43,6 +43,25 @@ def tuple_fallback(typ: TupleType) -> Instance:
4343
return Instance(info, [join_type_list(typ.items)])
4444

4545

46+
def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]:
47+
"""Returns the Instance fallback for this type if one exists.
48+
49+
Otherwise, returns None.
50+
"""
51+
if isinstance(typ, Instance):
52+
return typ
53+
elif isinstance(typ, TupleType):
54+
return tuple_fallback(typ)
55+
elif isinstance(typ, TypedDictType):
56+
return typ.fallback
57+
elif isinstance(typ, FunctionLike):
58+
return typ.fallback
59+
elif isinstance(typ, LiteralType):
60+
return typ.fallback
61+
else:
62+
return None
63+
64+
4665
def type_object_type_from_function(signature: FunctionLike,
4766
info: TypeInfo,
4867
def_info: TypeInfo,
@@ -475,27 +494,66 @@ def try_getting_str_literals(expr: Expression, typ: Type) -> Optional[List[str]]
475494
2. 'typ' is a LiteralType containing a string
476495
3. 'typ' is a UnionType containing only LiteralType of strings
477496
"""
478-
typ = get_proper_type(typ)
479-
480497
if isinstance(expr, StrExpr):
481498
return [expr.value]
482499

500+
# TODO: See if we can eliminate this function and call the below one directly
501+
return try_getting_str_literals_from_type(typ)
502+
503+
504+
def try_getting_str_literals_from_type(typ: Type) -> Optional[List[str]]:
505+
"""If the given expression or type corresponds to a string Literal
506+
or a union of string Literals, returns a list of the underlying strings.
507+
Otherwise, returns None.
508+
509+
For example, if we had the type 'Literal["foo", "bar"]' as input, this function
510+
would return a list of strings ["foo", "bar"].
511+
"""
512+
return try_getting_literals_from_type(typ, str, "builtins.str")
513+
514+
515+
def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]:
516+
"""If the given expression or type corresponds to an int Literal
517+
or a union of int Literals, returns a list of the underlying ints.
518+
Otherwise, returns None.
519+
520+
For example, if we had the type 'Literal[1, 2, 3]' as input, this function
521+
would return a list of ints [1, 2, 3].
522+
"""
523+
return try_getting_literals_from_type(typ, int, "builtins.int")
524+
525+
526+
T = TypeVar('T')
527+
528+
529+
def try_getting_literals_from_type(typ: Type,
530+
target_literal_type: TypingType[T],
531+
target_fullname: str) -> Optional[List[T]]:
532+
"""If the given expression or type corresponds to a Literal or
533+
union of Literals where the underlying values corresponds to the given
534+
target type, returns a list of those underlying values. Otherwise,
535+
returns None.
536+
"""
537+
typ = get_proper_type(typ)
538+
483539
if isinstance(typ, Instance) and typ.last_known_value is not None:
484540
possible_literals = [typ.last_known_value] # type: List[Type]
485541
elif isinstance(typ, UnionType):
486542
possible_literals = list(typ.items)
487543
else:
488544
possible_literals = [typ]
489545

490-
strings = []
546+
literals = [] # type: List[T]
491547
for lit in get_proper_types(possible_literals):
492-
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == 'builtins.str':
548+
if isinstance(lit, LiteralType) and lit.fallback.type.fullname() == target_fullname:
493549
val = lit.value
494-
assert isinstance(val, str)
495-
strings.append(val)
550+
if isinstance(val, target_literal_type):
551+
literals.append(val)
552+
else:
553+
return None
496554
else:
497555
return None
498-
return strings
556+
return literals
499557

500558

501559
def get_enum_values(typ: Instance) -> List[str]:
@@ -587,3 +645,31 @@ def coerce_to_literal(typ: Type) -> ProperType:
587645
if len(enum_values) == 1:
588646
return LiteralType(value=enum_values[0], fallback=typ)
589647
return typ
648+
649+
650+
def lookup_attribute_type(typ: Type, key: Union[str, int]) -> Optional[Type]:
651+
typ = get_proper_type(typ)
652+
if isinstance(key, int):
653+
# Int keys apply to tuples and namedtuples
654+
if isinstance(typ, TupleType):
655+
try:
656+
return typ.items[key]
657+
except IndexError:
658+
return None
659+
else:
660+
# Str keys apply to typed dicts, named tuples, instances, and anything that has
661+
# an instance fallback
662+
if isinstance(typ, TypedDictType):
663+
return typ.items.get(key)
664+
665+
instance = try_getting_instance_fallback(typ)
666+
if instance is None:
667+
return None
668+
669+
symbol = instance.type.get(key)
670+
if symbol is None:
671+
return None
672+
673+
return expand_type_by_instance(symbol.type, instance)
674+
675+
return None

0 commit comments

Comments
 (0)