Skip to content

Commit a690223

Browse files
committed
more enum-related speedups
As a followup to #9394 address a few more O(n**2) behaviors caused by decomposing enums into unions of literals.
1 parent 0cec4f7 commit a690223

File tree

4 files changed

+95
-18
lines changed

4 files changed

+95
-18
lines changed

mypy/meet.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
6464
if isinstance(declared, UnionType):
6565
return make_simplified_union([narrow_declared_type(x, narrowed)
6666
for x in declared.relevant_items()])
67+
if is_enum_overlapping_union(declared, narrowed):
68+
return narrowed
6769
elif not is_overlapping_types(declared, narrowed,
6870
prohibit_none_typevar_overlap=True):
6971
if state.strict_optional:
@@ -137,6 +139,24 @@ def get_possible_variants(typ: Type) -> List[Type]:
137139
return [typ]
138140

139141

142+
def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
143+
return (
144+
isinstance(x, Instance) and x.type.is_enum and
145+
isinstance(y, UnionType) and
146+
all(isinstance(z, LiteralType) and z.fallback.type == x.type # type: ignore[misc]
147+
for z in y.items)
148+
)
149+
150+
151+
def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
152+
return (
153+
isinstance(x, LiteralType) and isinstance(y, UnionType) and any(
154+
isinstance(z, LiteralType) and z == x # type: ignore[misc]
155+
for z in y.items
156+
)
157+
)
158+
159+
140160
def is_overlapping_types(left: Type,
141161
right: Type,
142162
ignore_promotions: bool = False,
@@ -198,6 +218,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
198218
#
199219
# These checks will also handle the NoneType and UninhabitedType cases for us.
200220

221+
# enums are sometimes expanded into an Union of Literals
222+
# when that happens we want to make sure we treat the two as overlapping
223+
# and crucially, we want to do that *fast* in case the enum is large
224+
# so we do it before expanding variants below to avoid O(n**2) behavior
225+
if (
226+
is_enum_overlapping_union(left, right) or
227+
is_enum_overlapping_union(right, left) or
228+
is_literal_in_union(left, right) or
229+
is_literal_in_union(right, left)
230+
):
231+
return True
232+
201233
if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
202234
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
203235
return True

mypy/sametypes.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Sequence
1+
from typing import Sequence, Tuple, Set, List
22

33
from mypy.types import (
44
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
55
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
66
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
77
ProperType, get_proper_type, TypeAliasType, ParamSpecType
88
)
9-
from mypy.typeops import tuple_fallback, make_simplified_union
9+
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal
1010

1111

1212
def is_same_type(left: Type, right: Type) -> bool:
@@ -143,14 +143,32 @@ def visit_literal_type(self, left: LiteralType) -> bool:
143143

144144
def visit_union_type(self, left: UnionType) -> bool:
145145
if isinstance(self.right, UnionType):
146+
# fast path for simple literals
147+
def _extract_literals(u: UnionType) -> Tuple[Set[LiteralType], List[Type]]:
148+
lit = set() # type: Set[LiteralType]
149+
rem = [] # type: List[Type]
150+
for i in u.items:
151+
if is_simple_literal(i):
152+
assert isinstance(i, LiteralType) # type: ignore[misc]
153+
lit.add(i)
154+
else:
155+
rem.append(i)
156+
return lit, rem
157+
158+
left_lit, left_rem = _extract_literals(left)
159+
right_lit, right_rem = _extract_literals(self.right)
160+
161+
if left_lit != right_lit:
162+
return False
163+
146164
# Check that everything in left is in right
147-
for left_item in left.items:
148-
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
165+
for left_item in left_rem:
166+
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
149167
return False
150168

151169
# Check that everything in right is in left
152-
for right_item in self.right.items:
153-
if not any(is_same_type(right_item, left_item) for left_item in left.items):
170+
for right_item in right_rem:
171+
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
154172
return False
155173

156174
return True

mypy/subtypes.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,19 @@ def visit_overloaded(self, left: Overloaded) -> bool:
498498
return False
499499

500500
def visit_union_type(self, left: UnionType) -> bool:
501+
if isinstance(self.right, Instance):
502+
literal_types = set() # type: Set[Instance]
503+
# avoid redundant check for union of literals
504+
for item in left.items:
505+
if mypy.typeops.is_simple_literal(item):
506+
assert isinstance(item, LiteralType) # type: ignore[misc]
507+
if item.fallback in literal_types:
508+
continue
509+
literal_types.add(item.fallback)
510+
item = item.fallback
511+
if not self._is_subtype(item, self.orig_right):
512+
return False
513+
return True
501514
return all(self._is_subtype(item, self.orig_right) for item in left.items)
502515

503516
def visit_partial_type(self, left: PartialType) -> bool:
@@ -1137,6 +1150,17 @@ def report(*args: Any) -> None:
11371150
return applied
11381151

11391152

1153+
def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
1154+
new_items = [] # type: List[Type]
1155+
for i in t.relevant_items():
1156+
it = get_proper_type(i)
1157+
if not mypy.typeops.is_simple_literal(it):
1158+
return None
1159+
if it != s:
1160+
new_items.append(i)
1161+
return new_items
1162+
1163+
11401164
def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
11411165
"""Return t minus s for runtime type assertions.
11421166
@@ -1150,10 +1174,13 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
11501174
s = get_proper_type(s)
11511175

11521176
if isinstance(t, UnionType):
1153-
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1154-
for item in t.relevant_items()
1155-
if (isinstance(get_proper_type(item), AnyType) or
1156-
not covers_at_runtime(item, s, ignore_promotions))]
1177+
new_items = try_restrict_literal_union(t, s) if isinstance(s, LiteralType) else []
1178+
new_items = new_items or [
1179+
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1180+
for item in t.relevant_items()
1181+
if (isinstance(get_proper_type(item), AnyType) or
1182+
not covers_at_runtime(item, s, ignore_promotions))
1183+
]
11571184
return UnionType.make_union(new_items)
11581185
elif covers_at_runtime(t, s, ignore_promotions):
11591186
return UninhabitedType()
@@ -1223,11 +1250,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
12231250
right = get_proper_type(right)
12241251

12251252
if isinstance(right, UnionType) and not isinstance(left, UnionType):
1226-
return any([is_proper_subtype(orig_left, item,
1227-
ignore_promotions=ignore_promotions,
1228-
erase_instances=erase_instances,
1229-
keep_erased_types=keep_erased_types)
1230-
for item in right.items])
1253+
return any(is_proper_subtype(orig_left, item,
1254+
ignore_promotions=ignore_promotions,
1255+
erase_instances=erase_instances,
1256+
keep_erased_types=keep_erased_types)
1257+
for item in right.items)
12311258
return left.accept(ProperSubtypeVisitor(orig_right,
12321259
ignore_promotions=ignore_promotions,
12331260
erase_instances=erase_instances,
@@ -1418,7 +1445,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
14181445
return False
14191446

14201447
def visit_union_type(self, left: UnionType) -> bool:
1421-
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
1448+
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)
14221449

14231450
def visit_partial_type(self, left: PartialType) -> bool:
14241451
# TODO: What's the right thing to do here?

mypy/typeops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,13 @@ def callable_corresponding_argument(typ: CallableType,
294294
return by_name if by_name is not None else by_pos
295295

296296

297-
def is_simple_literal(t: ProperType) -> bool:
297+
def is_simple_literal(t: Type) -> bool:
298298
"""
299299
Whether a type is a simple enough literal to allow for fast Union simplification
300300
301301
For now this means enum or string
302302
"""
303-
return isinstance(t, LiteralType) and (
303+
return isinstance(t, LiteralType) and ( # type: ignore[misc]
304304
t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str'
305305
)
306306

0 commit comments

Comments
 (0)