Skip to content

Commit 55a0812

Browse files
authored
Infer unions for ternary expressions (#17427)
Ref #12056 cc @JukkaL Again, let's check the primer...
1 parent f297917 commit 55a0812

File tree

9 files changed

+99
-65
lines changed

9 files changed

+99
-65
lines changed

mypy/checkexpr.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5766,16 +5766,15 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
57665766
context=if_type_fallback,
57675767
allow_none_return=allow_none_return,
57685768
)
5769-
5770-
# Only create a union type if the type context is a union, to be mostly
5771-
# compatible with older mypy versions where we always did a join.
5772-
#
5773-
# TODO: Always create a union or at least in more cases?
5774-
if isinstance(get_proper_type(self.type_context[-1]), UnionType):
5775-
res: Type = make_simplified_union([if_type, full_context_else_type])
5776-
else:
5777-
res = join.join_types(if_type, else_type)
5778-
5769+
res: Type = make_simplified_union([if_type, else_type])
5770+
if has_uninhabited_component(res) and not isinstance(
5771+
get_proper_type(self.type_context[-1]), UnionType
5772+
):
5773+
# In rare cases with empty collections join may give a better result.
5774+
alternative = join.join_types(if_type, else_type)
5775+
p_alt = get_proper_type(alternative)
5776+
if not isinstance(p_alt, Instance) or p_alt.type.fullname != "builtins.object":
5777+
res = alternative
57795778
return res
57805779

57815780
def analyze_cond_branch(

mypyc/test-data/irbuild-any.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def f4(a, n, b):
151151
a :: object
152152
n :: int
153153
b :: bool
154-
r0, r1, r2, r3 :: object
154+
r0 :: union[object, int]
155+
r1, r2 :: object
156+
r3 :: union[int, object]
155157
r4 :: int
156158
L0:
157159
if b goto L1 else goto L2 :: bool

test-data/unit/check-errorcodes.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ a: D = {'x': ''} # E: Incompatible types (expression has type "str", TypedDict
460460
b: D = {'y': ''} # E: Missing key "x" for TypedDict "D" [typeddict-item] \
461461
# E: Extra key "y" for TypedDict "D" [typeddict-unknown-key]
462462
c = D(x=0) if int() else E(x=0, y=0)
463-
c = {} # E: Expected TypedDict key "x" but found no keys [typeddict-item]
463+
c = {} # E: Missing key "x" for TypedDict "D" [typeddict-item]
464464
d: D = {'x': '', 'y': 1} # E: Extra key "y" for TypedDict "D" [typeddict-unknown-key] \
465465
# E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") [typeddict-item]
466466

test-data/unit/check-expressions.test

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,10 +1470,9 @@ if int():
14701470

14711471
[case testConditionalExpressionUnion]
14721472
from typing import Union
1473-
reveal_type(1 if bool() else 2) # N: Revealed type is "builtins.int"
1474-
reveal_type(1 if bool() else '') # N: Revealed type is "builtins.object"
1475-
x: Union[int, str] = reveal_type(1 if bool() else '') \
1476-
# N: Revealed type is "Union[Literal[1]?, Literal['']?]"
1473+
reveal_type(1 if bool() else 2) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]"
1474+
reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
1475+
x: Union[int, str] = reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
14771476
class A:
14781477
pass
14791478
class B(A):
@@ -1487,17 +1486,17 @@ b = B()
14871486
c = C()
14881487
d = D()
14891488
reveal_type(a if bool() else b) # N: Revealed type is "__main__.A"
1490-
reveal_type(b if bool() else c) # N: Revealed type is "builtins.object"
1491-
reveal_type(c if bool() else b) # N: Revealed type is "builtins.object"
1492-
reveal_type(c if bool() else a) # N: Revealed type is "builtins.object"
1493-
reveal_type(d if bool() else b) # N: Revealed type is "__main__.A"
1489+
reveal_type(b if bool() else c) # N: Revealed type is "Union[__main__.B, __main__.C]"
1490+
reveal_type(c if bool() else b) # N: Revealed type is "Union[__main__.C, __main__.B]"
1491+
reveal_type(c if bool() else a) # N: Revealed type is "Union[__main__.C, __main__.A]"
1492+
reveal_type(d if bool() else b) # N: Revealed type is "Union[__main__.D, __main__.B]"
14941493
[builtins fixtures/bool.pyi]
14951494

14961495
[case testConditionalExpressionUnionWithAny]
14971496
from typing import Union, Any
14981497
a: Any
14991498
x: Union[int, str] = reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"
1500-
reveal_type(a if int() else 1) # N: Revealed type is "Any"
1499+
reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"
15011500

15021501
[case testConditionalExpressionStatementNoReturn]
15031502
from typing import List, Union

test-data/unit/check-functions.test

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2250,13 +2250,26 @@ def dec(f: Callable[[A, str], None]) -> Callable[[A, int], None]: pass
22502250
[out]
22512251

22522252
[case testUnknownFunctionNotCallable]
2253+
from typing import TypeVar
2254+
22532255
def f() -> None:
22542256
pass
22552257
def g(x: int) -> None:
22562258
pass
22572259
h = f if bool() else g
2258-
reveal_type(h) # N: Revealed type is "builtins.function"
2259-
h(7) # E: Cannot call function of unknown type
2260+
reveal_type(h) # N: Revealed type is "Union[def (), def (x: builtins.int)]"
2261+
h(7) # E: Too many arguments for "f"
2262+
2263+
T = TypeVar("T")
2264+
def join(x: T, y: T) -> T: ...
2265+
2266+
h2 = join(f, g)
2267+
reveal_type(h2) # N: Revealed type is "builtins.function"
2268+
h2(7) # E: Cannot call function of unknown type
2269+
2270+
h3 = join(g, f)
2271+
reveal_type(h3) # N: Revealed type is "builtins.function"
2272+
h3(7) # E: Cannot call function of unknown type
22602273
[builtins fixtures/bool.pyi]
22612274

22622275
[case testFunctionWithNameUnderscore]

test-data/unit/check-inference-context.test

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ class A: pass
701701
class B(A): pass
702702
class C(A): pass
703703
def f(func: Callable[[T], S], *z: T, r: Optional[S] = None) -> S: pass
704-
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "builtins.int"
704+
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "Union[Literal[0]?, Literal[1]?]"
705705
f(lambda x: 0 if isinstance(x, B) else 1, A())() # E: "int" not callable
706706
f(lambda x: x if isinstance(x, B) else B(), A(), r=B())() # E: "B" not callable
707707
f(
@@ -1391,15 +1391,15 @@ from typing import Union, List, Any
13911391

13921392
def f(x: Union[List[str], Any]) -> None:
13931393
a = x if x else []
1394-
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], builtins.list[builtins.str], Any]"
1394+
reveal_type(a) # N: Revealed type is "Union[builtins.list[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
13951395
[builtins fixtures/list.pyi]
13961396

13971397
[case testConditionalExpressionWithEmptyIteableAndUnionWithAny]
13981398
from typing import Union, Iterable, Any
13991399

14001400
def f(x: Union[Iterable[str], Any]) -> None:
14011401
a = x if x else []
1402-
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], typing.Iterable[builtins.str], Any]"
1402+
reveal_type(a) # N: Revealed type is "Union[typing.Iterable[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
14031403
[builtins fixtures/list.pyi]
14041404

14051405
[case testInferMultipleAnyUnionCovariant]

test-data/unit/check-inference.test

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,18 +1438,22 @@ class Wrapper:
14381438

14391439
def f(cond: bool) -> Any:
14401440
f = Wrapper if cond else lambda x: x
1441-
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
1441+
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> __main__.Wrapper, def (x: Any) -> Any]"
14421442
return f(3)
14431443

14441444
def g(cond: bool) -> Any:
14451445
f = lambda x: x if cond else Wrapper
1446-
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
1446+
reveal_type(f) # N: Revealed type is "def (x: Any) -> Union[Any, def (x: Any) -> __main__.Wrapper]"
1447+
return f(3)
1448+
1449+
def h(cond: bool) -> Any:
1450+
f = (lambda x: x) if cond else Wrapper
1451+
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> Any, def (x: Any) -> __main__.Wrapper]"
14471452
return f(3)
14481453

14491454
-- Boolean operators
14501455
-- -----------------
14511456

1452-
14531457
[case testOrOperationWithGenericOperands]
14541458
from typing import List
14551459
a: List[A]

test-data/unit/check-optional.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def lookup_field(name, obj):
395395
attr = None
396396

397397
[case testTernaryWithNone]
398-
reveal_type(None if bool() else 0) # N: Revealed type is "Union[Literal[0]?, None]"
398+
reveal_type(None if bool() else 0) # N: Revealed type is "Union[None, Literal[0]?]"
399399
[builtins fixtures/bool.pyi]
400400

401401
[case testListWithNone]

test-data/unit/check-tuples.test

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,68 +1228,76 @@ x, y = g(z) # E: Argument 1 to "g" has incompatible type "int"; expected "Tuple[
12281228
[out]
12291229

12301230
[case testFixedTupleJoinVarTuple]
1231-
from typing import Tuple
1231+
from typing import Tuple, TypeVar
12321232

12331233
class A: pass
12341234
class B(A): pass
12351235

12361236
fixtup: Tuple[B, B]
12371237

1238+
T = TypeVar("T")
1239+
def join(x: T, y: T) -> T: ...
1240+
12381241
vartup_b: Tuple[B, ...]
1239-
reveal_type(fixtup if int() else vartup_b) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
1240-
reveal_type(vartup_b if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
1242+
reveal_type(join(fixtup, vartup_b)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
1243+
reveal_type(join(vartup_b, fixtup)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
12411244

12421245
vartup_a: Tuple[A, ...]
1243-
reveal_type(fixtup if int() else vartup_a) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1244-
reveal_type(vartup_a if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1245-
1246+
reveal_type(join(fixtup, vartup_a)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1247+
reveal_type(join(vartup_a, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
12461248

12471249
[builtins fixtures/tuple.pyi]
12481250
[out]
12491251

12501252
[case testFixedTupleJoinList]
1251-
from typing import Tuple, List
1253+
from typing import Tuple, List, TypeVar
12521254

12531255
class A: pass
12541256
class B(A): pass
12551257

12561258
fixtup: Tuple[B, B]
12571259

1260+
T = TypeVar("T")
1261+
def join(x: T, y: T) -> T: ...
1262+
12581263
lst_b: List[B]
1259-
reveal_type(fixtup if int() else lst_b) # N: Revealed type is "typing.Sequence[__main__.B]"
1260-
reveal_type(lst_b if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.B]"
1264+
reveal_type(join(fixtup, lst_b)) # N: Revealed type is "typing.Sequence[__main__.B]"
1265+
reveal_type(join(lst_b, fixtup)) # N: Revealed type is "typing.Sequence[__main__.B]"
12611266

12621267
lst_a: List[A]
1263-
reveal_type(fixtup if int() else lst_a) # N: Revealed type is "typing.Sequence[__main__.A]"
1264-
reveal_type(lst_a if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.A]"
1268+
reveal_type(join(fixtup, lst_a)) # N: Revealed type is "typing.Sequence[__main__.A]"
1269+
reveal_type(join(lst_a, fixtup)) # N: Revealed type is "typing.Sequence[__main__.A]"
12651270

12661271
[builtins fixtures/tuple.pyi]
12671272
[out]
12681273

12691274
[case testEmptyTupleJoin]
1270-
from typing import Tuple, List
1275+
from typing import Tuple, List, TypeVar
12711276

12721277
class A: pass
12731278

12741279
empty = ()
12751280

1281+
T = TypeVar("T")
1282+
def join(x: T, y: T) -> T: ...
1283+
12761284
fixtup: Tuple[A]
1277-
reveal_type(fixtup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1278-
reveal_type(empty if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1285+
reveal_type(join(fixtup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1286+
reveal_type(join(empty, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
12791287

12801288
vartup: Tuple[A, ...]
1281-
reveal_type(empty if int() else vartup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1282-
reveal_type(vartup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1289+
reveal_type(join(vartup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
1290+
reveal_type(join(empty, vartup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
12831291

12841292
lst: List[A]
1285-
reveal_type(empty if int() else lst) # N: Revealed type is "typing.Sequence[__main__.A]"
1286-
reveal_type(lst if int() else empty) # N: Revealed type is "typing.Sequence[__main__.A]"
1293+
reveal_type(join(empty, lst)) # N: Revealed type is "typing.Sequence[__main__.A]"
1294+
reveal_type(join(lst, empty)) # N: Revealed type is "typing.Sequence[__main__.A]"
12871295

12881296
[builtins fixtures/tuple.pyi]
12891297
[out]
12901298

12911299
[case testTupleSubclassJoin]
1292-
from typing import Tuple, NamedTuple
1300+
from typing import Tuple, NamedTuple, TypeVar
12931301

12941302
class NTup(NamedTuple):
12951303
a: bool
@@ -1302,32 +1310,38 @@ ntup: NTup
13021310
subtup: SubTuple
13031311
vartup: SubVarTuple
13041312

1305-
reveal_type(ntup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1306-
reveal_type(subtup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1313+
T = TypeVar("T")
1314+
def join(x: T, y: T) -> T: ...
1315+
1316+
reveal_type(join(ntup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1317+
reveal_type(join(subtup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
13071318

13081319
[builtins fixtures/tuple.pyi]
13091320
[out]
13101321

13111322
[case testTupleJoinIrregular]
1312-
from typing import Tuple
1323+
from typing import Tuple, TypeVar
13131324

13141325
tup1: Tuple[bool, int]
13151326
tup2: Tuple[bool]
13161327

1317-
reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1318-
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1328+
T = TypeVar("T")
1329+
def join(x: T, y: T) -> T: ...
1330+
1331+
reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1332+
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
13191333

1320-
reveal_type(tup1 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1321-
reveal_type(() if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1334+
reveal_type(join(tup1, ())) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1335+
reveal_type(join((), tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
13221336

1323-
reveal_type(tup2 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1324-
reveal_type(() if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1337+
reveal_type(join(tup2, ())) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1338+
reveal_type(join((), tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
13251339

13261340
[builtins fixtures/tuple.pyi]
13271341
[out]
13281342

13291343
[case testTupleSubclassJoinIrregular]
1330-
from typing import Tuple, NamedTuple
1344+
from typing import Tuple, NamedTuple, TypeVar
13311345

13321346
class NTup1(NamedTuple):
13331347
a: bool
@@ -1342,14 +1356,17 @@ tup1: NTup1
13421356
tup2: NTup2
13431357
subtup: SubTuple
13441358

1345-
reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1346-
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1359+
T = TypeVar("T")
1360+
def join(x: T, y: T) -> T: ...
1361+
1362+
reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
1363+
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
13471364

1348-
reveal_type(tup1 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1349-
reveal_type(subtup if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1365+
reveal_type(join(tup1, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1366+
reveal_type(join(subtup, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
13501367

1351-
reveal_type(tup2 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1352-
reveal_type(subtup if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1368+
reveal_type(join(tup2, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
1369+
reveal_type(join(subtup, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
13531370

13541371
[builtins fixtures/tuple.pyi]
13551372
[out]

0 commit comments

Comments
 (0)