Skip to content

Commit c53e0ca

Browse files
committed
Fixes to union simplification and isinstance
The main change is that unions containing Any are no longer simplified to just Any. This required changes in various other places to keep the existing semantics, and resulted in some fixes to existing test cases.
1 parent 9d15e25 commit c53e0ca

11 files changed

+245
-41
lines changed

mypy/checkexpr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from mypy import messages
3333
from mypy.infer import infer_type_arguments, infer_function_type_arguments
3434
from mypy import join
35-
from mypy.meet import meet_simple
35+
from mypy.meet import narrow_declared_type
3636
from mypy.maptype import map_instance_to_supertype
3737
from mypy.subtypes import is_subtype, is_equivalent
3838
from mypy import applytype
@@ -2213,8 +2213,7 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
22132213
if expr.literal >= LITERAL_TYPE:
22142214
restriction = self.chk.binder.get(expr)
22152215
if restriction:
2216-
ans = meet_simple(known_type, restriction)
2217-
return ans
2216+
return narrow_declared_type(known_type, restriction)
22182217
return known_type
22192218

22202219

mypy/meet.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,25 @@ def meet_types(s: Type, t: Type) -> Type:
2525
return t.accept(TypeMeetVisitor(s))
2626

2727

28-
def meet_simple(s: Type, t: Type, default_right: bool = True) -> Type:
29-
if s == t:
30-
return s
31-
if isinstance(s, UnionType):
32-
return UnionType.make_simplified_union([meet_types(x, t) for x in s.items])
33-
elif not is_overlapping_types(s, t, use_promotions=True):
28+
def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
29+
"""Return the declared type narrowed down to another type."""
30+
# TODO: What are the reasons for not just using meet_types()?
31+
if declared == narrowed:
32+
return declared
33+
if isinstance(declared, UnionType):
34+
return UnionType.make_simplified_union([narrow_declared_type(x, narrowed)
35+
for x in declared.items])
36+
elif not is_overlapping_types(declared, narrowed, use_promotions=True):
3437
if experiments.STRICT_OPTIONAL:
3538
return UninhabitedType()
3639
else:
3740
return NoneTyp()
38-
else:
39-
if default_right:
40-
return t
41-
else:
42-
return s
41+
elif isinstance(narrowed, UnionType):
42+
return UnionType.make_simplified_union([narrow_declared_type(declared, x)
43+
for x in narrowed.items])
44+
elif isinstance(narrowed, AnyType):
45+
return narrowed
46+
return meet_types(declared, narrowed)
4347

4448

4549
def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool:

mypy/subtypes.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,12 +508,17 @@ def restrict_subtype_away(t: Type, s: Type) -> Type:
508508
return t
509509

510510

511-
def is_proper_subtype(t: Type, s: Type) -> bool:
511+
def is_proper_subtype(left: Type, right: Type) -> bool:
512512
"""Check if t is a proper subtype of s?
513513
514514
For proper subtypes, there's no need to rely on compatibility due to
515515
Any types. Any instance type t is also a proper subtype of t.
516516
"""
517+
if isinstance(right, UnionType) and not isinstance(left, UnionType):
518+
return any([is_proper_subtype(left, item)
519+
for item in right.items])
520+
return left.accept(ProperSubtypeVisitor(right))
521+
517522
# FIX tuple types
518523
if isinstance(s, UnionType):
519524
if isinstance(t, UnionType):
@@ -544,6 +549,122 @@ def check_argument(left: Type, right: Type, variance: int) -> bool:
544549
return sametypes.is_same_type(t, s)
545550

546551

552+
class ProperSubtypeVisitor(TypeVisitor[bool]):
553+
def __init__(self, right: Type) -> None:
554+
self.right = right
555+
556+
def visit_unbound_type(self, left: UnboundType) -> bool:
557+
return True
558+
559+
def visit_error_type(self, left: ErrorType) -> bool:
560+
return False
561+
562+
def visit_type_list(self, left: TypeList) -> bool:
563+
assert False, 'Should not happen'
564+
565+
def visit_any(self, left: AnyType) -> bool:
566+
return isinstance(self.right, AnyType)
567+
568+
def visit_void(self, left: Void) -> bool:
569+
return True
570+
571+
def visit_none_type(self, left: NoneTyp) -> bool:
572+
if experiments.STRICT_OPTIONAL:
573+
return (isinstance(self.right, NoneTyp) or
574+
is_named_instance(self.right, 'builtins.object'))
575+
else:
576+
return not isinstance(self.right, Void)
577+
578+
def visit_uninhabited_type(self, left: UninhabitedType) -> bool:
579+
return not isinstance(self.right, Void)
580+
581+
def visit_erased_type(self, left: ErasedType) -> bool:
582+
return True
583+
584+
def visit_deleted_type(self, left: DeletedType) -> bool:
585+
return True
586+
587+
def visit_instance(self, left: Instance) -> bool:
588+
if isinstance(self.right, Instance):
589+
if not left.type.has_base(self.right.type.fullname()):
590+
return False
591+
592+
def check_argument(leftarg: Type, rightarg: Type, variance: int) -> bool:
593+
if variance == COVARIANT:
594+
return is_proper_subtype(leftarg, rightarg)
595+
elif variance == CONTRAVARIANT:
596+
return is_proper_subtype(rightarg, leftarg)
597+
else:
598+
return sametypes.is_same_type(leftarg, rightarg)
599+
600+
# Map left type to corresponding right instances.
601+
left = map_instance_to_supertype(left, self.right.type)
602+
603+
return all(check_argument(ta, ra, tvar.variance) for ta, ra, tvar in
604+
zip(left.args, self.right.args, self.right.type.defn.type_vars))
605+
return False
606+
607+
def visit_type_var(self, left: TypeVarType) -> bool:
608+
if isinstance(self.right, TypeVarType) and left.id == self.right.id:
609+
return True
610+
return is_proper_subtype(left.upper_bound, self.right)
611+
612+
def visit_callable_type(self, left: CallableType) -> bool:
613+
# TODO: Implement this properly
614+
return is_subtype(left, self.right)
615+
616+
def visit_tuple_type(self, left: TupleType) -> bool:
617+
right = self.right
618+
if isinstance(right, Instance):
619+
if (is_named_instance(right, 'builtins.tuple') or
620+
is_named_instance(right, 'typing.Iterable') or
621+
is_named_instance(right, 'typing.Container') or
622+
is_named_instance(right, 'typing.Sequence') or
623+
is_named_instance(right, 'typing.Reversible')):
624+
if not right.args:
625+
return False
626+
iter_type = right.args[0]
627+
return all(is_proper_subtype(li, iter_type) for li in left.items)
628+
return is_proper_subtype(left.fallback, right)
629+
elif isinstance(right, TupleType):
630+
if len(left.items) != len(right.items):
631+
return False
632+
for l, r in zip(left.items, right.items):
633+
if not is_proper_subtype(l, r):
634+
return False
635+
return is_proper_subtype(left.fallback, right.fallback)
636+
return False
637+
638+
def visit_typeddict_type(self, left: TypedDictType) -> bool:
639+
# TODO: Does it make sense to support TypedDict here?
640+
return False
641+
642+
def visit_overloaded(self, left: Overloaded) -> bool:
643+
# TODO: What's the right thing to do here?
644+
return False
645+
646+
def visit_union_type(self, left: UnionType) -> bool:
647+
return all([is_proper_subtype(item, self.right) for item in left.items])
648+
649+
def visit_partial_type(self, left: PartialType) -> bool:
650+
# TODO: What's the right thing to do here?
651+
return False
652+
653+
def visit_type_type(self, left: TypeType) -> bool:
654+
right = self.right
655+
if isinstance(right, TypeType):
656+
return is_proper_subtype(left.item, right.item)
657+
if isinstance(right, CallableType):
658+
# This is unsound, we don't check the __init__ signature.
659+
return right.is_type_obj() and is_proper_subtype(left.item, right.ret_type)
660+
if isinstance(right, Instance):
661+
if right.type.fullname() in ('builtins.type', 'builtins.object'):
662+
return True
663+
item = left.item
664+
return isinstance(item, Instance) and is_proper_subtype(item, right.type.metaclass_type)
665+
return False
666+
667+
547668
def is_more_precise(t: Type, s: Type) -> bool:
548669
"""Check if t is a more precise type than s.
549670

mypy/types.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -995,11 +995,7 @@ def make_simplified_union(items: List[Type], line: int = -1, column: int = -1) -
995995
all_items.append(typ)
996996
items = all_items
997997

998-
if any(isinstance(typ, AnyType) for typ in items):
999-
return AnyType()
1000-
1001-
from mypy.subtypes import is_subtype
1002-
from mypy.sametypes import is_same_type
998+
from mypy.subtypes import is_proper_subtype
1003999

10041000
removed = set() # type: Set[int]
10051001
for i, ti in enumerate(items):
@@ -1008,9 +1004,7 @@ def make_simplified_union(items: List[Type], line: int = -1, column: int = -1) -
10081004
cbt = cbf = False
10091005
for j, tj in enumerate(items):
10101006
if (i != j
1011-
and is_subtype(tj, ti)
1012-
and (not (isinstance(tj, Instance) and tj.type.fallback_to_any)
1013-
or is_same_type(ti, tj))):
1007+
and is_proper_subtype(tj, ti)):
10141008
removed.add(j)
10151009
cbt = cbt or tj.can_be_true
10161010
cbf = cbf or tj.can_be_false

test-data/unit/check-classes.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2388,9 +2388,9 @@ class A: pass
23882388
class B(A): pass
23892389

23902390
@overload
2391-
def f(a: Type[A]) -> int: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
2391+
def f(a: Type[B]) -> int: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
23922392
@overload
2393-
def f(a: Type[B]) -> str: pass
2393+
def f(a: Type[A]) -> str: pass
23942394
[builtins fixtures/classmethod.pyi]
23952395
[out]
23962396

test-data/unit/check-dynamic-typing.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ n = 0
7979
d in a # E: Unsupported right operand type for in ("A")
8080
d and a
8181
d or a
82-
c = d and b # Unintuitive type inference?
83-
c = d or b # Unintuitive type inference?
82+
c = d and b # E: Incompatible types in assignment (expression has type "Union[Any, bool]", variable has type "C")
83+
c = d or b # E: Incompatible types in assignment (expression has type "Union[Any, bool]", variable has type "C")
8484

8585
c = d + a
8686
c = d - a
@@ -123,8 +123,8 @@ n = 0
123123
a and d
124124
a or d
125125
c = a in d
126-
c = b and d # Unintuitive type inference?
127-
c = b or d # Unintuitive type inference?
126+
c = b and d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
127+
c = b or d # E: Incompatible types in assignment (expression has type "Union[bool, Any]", variable has type "C")
128128
b = a + d
129129
b = a / d
130130

test-data/unit/check-generics.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ if not isinstance(s, str):
779779

780780
z = None # type: TNode # Same as TNode[Any]
781781
z.x
782-
z.foo() # Any simplifies Union to Any now. This test should be updated after #2197
782+
z.foo() # E: Some element of union has no attribute "foo"
783783

784784
[builtins fixtures/isinstance.pyi]
785785

test-data/unit/check-isinstance.test

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ def f(x: Union[List[int], List[str], int]) -> None:
399399
a + 'x' # E: Unsupported operand types for + ("int" and "str")
400400

401401
# type of a?
402+
reveal_type(x) # E: Revealed type is 'Union[builtins.list[builtins.int], builtins.list[builtins.str]]'
402403
x + 1 # E: Unsupported operand types for + (likely involving Union)
403404
else:
404405
x[0] # E: Value of type "int" is not indexable

test-data/unit/check-optional.test

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,9 +547,31 @@ a = None # type: Any
547547
reveal_type(u(C(), None)) # E: Revealed type is 'Union[builtins.None, __main__.C*]'
548548
reveal_type(u(None, C())) # E: Revealed type is 'Union[__main__.C*, builtins.None]'
549549

550-
# This will be fixed later
551-
reveal_type(u(a, None)) # E: Revealed type is 'Any'
552-
reveal_type(u(None, a)) # E: Revealed type is 'Any'
550+
reveal_type(u(a, None)) # E: Revealed type is 'Union[builtins.None, Any]'
551+
reveal_type(u(None, a)) # E: Revealed type is 'Union[Any, builtins.None]'
553552

554553
reveal_type(u(1, None)) # E: Revealed type is 'Union[builtins.None, builtins.int*]'
555554
reveal_type(u(None, 1)) # E: Revealed type is 'Union[builtins.int*, builtins.None]'
555+
556+
[case testOptionalAndAnyBaseClass]
557+
from typing import Any, Optional
558+
class C(Any):
559+
pass
560+
x = None # type: Optional[C]
561+
x.foo() # E: Some element of union has no attribute "foo"
562+
563+
[case testUnionSimplificationWithStrictOptional]
564+
from typing import Any, TypeVar, Union
565+
class C(Any): pass
566+
T = TypeVar('T')
567+
S = TypeVar('S')
568+
def u(x: T, y: S) -> Union[S, T]: pass
569+
a = None # type: Any
570+
571+
# Test both orders
572+
reveal_type(u(C(), None)) # E: Revealed type is 'Union[builtins.None, __main__.C*]'
573+
reveal_type(u(None, C())) # E: Revealed type is 'Union[__main__.C*, builtins.None]'
574+
reveal_type(u(a, None)) # E: Revealed type is 'Union[builtins.None, Any]'
575+
reveal_type(u(None, a)) # E: Revealed type is 'Union[Any, builtins.None]'
576+
reveal_type(u(1, None)) # E: Revealed type is 'Union[builtins.None, builtins.int*]'
577+
reveal_type(u(None, 1)) # E: Revealed type is 'Union[builtins.int*, builtins.None]'

test-data/unit/check-statements.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,11 @@ try:
723723
except BaseException as e1:
724724
reveal_type(e1) # E: Revealed type is 'builtins.BaseException'
725725
except (E1, BaseException) as e2:
726-
reveal_type(e2) # E: Revealed type is 'Any'
726+
reveal_type(e2) # E: Revealed type is 'Union[Any, builtins.BaseException]'
727727
except (E1, E2) as e3:
728-
reveal_type(e3) # E: Revealed type is 'Any'
728+
reveal_type(e3) # E: Revealed type is 'Union[Any, __main__.E2]'
729729
except (E1, E2, BaseException) as e4:
730-
reveal_type(e4) # E: Revealed type is 'Any'
730+
reveal_type(e4) # E: Revealed type is 'Union[Any, builtins.BaseException]'
731731

732732
try: pass
733733
except E1 as e1:

0 commit comments

Comments
 (0)