Skip to content

Commit 37c57da

Browse files
author
Guido van Rossum
committed
Improve unification for redundant unions and multiple inheritance.
1 parent 4ae9c35 commit 37c57da

File tree

2 files changed

+74
-16
lines changed

2 files changed

+74
-16
lines changed

mypy/join.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def visit_union_type(self, t: UnionType) -> Type:
8787
if is_subtype(self.s, t):
8888
return t
8989
else:
90-
return UnionType(t.items + [self.s])
90+
return UnionType.make_simplified_union([self.s, t])
9191

9292
def visit_error_type(self, t: ErrorType) -> Type:
9393
return t
@@ -235,7 +235,6 @@ def join_instances(t: Instance, s: Instance) -> Type:
235235
236236
Return ErrorType if the result is ambiguous.
237237
"""
238-
239238
if t.type == s.type:
240239
# Simplest case: join two types with the same base type (but
241240
# potentially different arguments).
@@ -264,16 +263,29 @@ def join_instances_via_supertype(t: Instance, s: Instance) -> Type:
264263
return join_types(t.type._promote, s)
265264
elif s.type._promote and is_subtype(s.type._promote, t):
266265
return join_types(t, s.type._promote)
267-
res = s
268-
mapped = map_instance_to_supertype(t, t.type.bases[0].type)
269-
join = join_instances(mapped, res)
270-
# If the join failed, fail. This is a defensive measure (this might
271-
# never happen).
272-
if isinstance(join, ErrorType):
273-
return join
274-
# Now the result must be an Instance, so the cast below cannot fail.
275-
res = cast(Instance, join)
276-
return res
266+
# Compute the "best" supertype of t when joined with s.
267+
# The definition of "best" may evolve; for now it is the one with
268+
# the longest MRO. Ties are broken by using the earlier base.
269+
best = None # type: Type
270+
for base in t.type.bases:
271+
mapped = map_instance_to_supertype(t, base.type)
272+
res = join_instances(mapped, s)
273+
if best is None or is_better(res, best):
274+
best = res
275+
assert best is not None
276+
return best
277+
278+
279+
def is_better(t: Type, s: Type) -> bool:
280+
# Given two possible results from join_instances_via_supertype(),
281+
# indicate whether t is the better one.
282+
if isinstance(t, Instance):
283+
if not isinstance(s, Instance):
284+
return True
285+
# Use len(mro) as a proxy for the better choice.
286+
if len(t.type.mro) > len(s.type.mro):
287+
return True
288+
return False
277289

278290

279291
def is_similar_callables(t: CallableType, s: CallableType) -> bool:

mypy/test/data/check-inference.test

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,9 @@ g(a)
599599
b = f(A(), B())
600600
g(b)
601601
c = f(A(), D())
602-
g(c) # E: Argument 1 to "g" has incompatible type "object"; expected "I"
602+
g(c) # E: Argument 1 to "g" has incompatible type "J"; expected "I"
603603
d = f(D(), A())
604-
g(d) # E: Argument 1 to "g" has incompatible type "object"; expected "I"
604+
g(d) # E: Argument 1 to "g" has incompatible type "J"; expected "I"
605605
e = f(D(), C())
606606
g(e) # E: Argument 1 to "g" has incompatible type "object"; expected "I"
607607

@@ -646,9 +646,9 @@ def f(a: T, b: T) -> T: pass
646646
def g(x: K) -> None: pass
647647

648648
a = f(B(), C())
649-
g(a) # E: Argument 1 to "g" has incompatible type "object"; expected "K"
649+
g(a) # E: Argument 1 to "g" has incompatible type "J"; expected "K"
650650
b = f(A(), C())
651-
g(b) # E: Argument 1 to "g" has incompatible type "object"; expected "K"
651+
g(b) # E: Argument 1 to "g" has incompatible type "J"; expected "K"
652652
c = f(A(), B())
653653
g(c)
654654

@@ -1593,3 +1593,49 @@ tmp/m.py: note: In function "g":
15931593
tmp/m.py:2: error: "int" not callable
15941594
main: note: In function "f":
15951595
main:3: error: "int" not callable
1596+
1597+
1598+
-- Tests for special cases of unification
1599+
-- --------------------------------------
1600+
1601+
[case testUnificationRedundantUnion]
1602+
from typing import Union
1603+
a = None # type: Union[int, str]
1604+
b = None # type: Union[str, tuple]
1605+
def f(): pass
1606+
def g(x: Union[int, str]): pass
1607+
c = a if f() else b
1608+
g(c) # E: Argument 1 to "g" has incompatible type "Union[int, str, tuple]"; expected "Union[int, str]"
1609+
1610+
[case testUnificationMultipleInheritance]
1611+
class A: pass
1612+
class B:
1613+
def foo(self): pass
1614+
class C(A, B): pass
1615+
def f(): pass
1616+
a1 = B() if f() else C()
1617+
a1.foo()
1618+
a2 = C() if f() else B()
1619+
a2.foo()
1620+
1621+
[case testUnificationMultipleInheritanceAmbiguous]
1622+
# Show that join_instances_via_supertype() breakes ties using the first base class.
1623+
class A1: pass
1624+
class B1:
1625+
def foo1(self): pass
1626+
class C1(A1, B1): pass
1627+
1628+
class A2: pass
1629+
class B2:
1630+
def foo2(self): pass
1631+
class C2(A2, B2): pass
1632+
1633+
class D1(C1, C2): pass
1634+
class D2(C2, C1): pass
1635+
1636+
def f(): pass
1637+
1638+
a1 = D1() if f() else D2()
1639+
a1.foo1()
1640+
a2 = D2() if f() else D1()
1641+
a2.foo2()

0 commit comments

Comments
 (0)