Skip to content

Commit 7e7eedd

Browse files
authored
Infer type for partial generic type from assignment (#8036)
Code like this no longer requires a type annotation: ``` a = [] if foo(): a = [1] ``` Work towards #1055.
1 parent b7465de commit 7e7eedd

File tree

3 files changed

+69
-29
lines changed

3 files changed

+69
-29
lines changed

mypy/checker.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
20422042
self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, rvalue,
20432043
infer_lvalue_type)
20442044
else:
2045+
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue)
20452046
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
20462047
# If we're assigning to __getattr__ or similar methods, check that the signature is
20472048
# valid.
@@ -2141,6 +2142,37 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
21412142
rvalue_type = remove_instance_last_known_values(rvalue_type)
21422143
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
21432144

2145+
def try_infer_partial_generic_type_from_assignment(self,
2146+
lvalue: Lvalue,
2147+
rvalue: Expression) -> None:
2148+
"""Try to infer a precise type for partial generic type from assignment.
2149+
2150+
Example where this happens:
2151+
2152+
x = []
2153+
if foo():
2154+
x = [1] # Infer List[int] as type of 'x'
2155+
"""
2156+
if (isinstance(lvalue, NameExpr)
2157+
and isinstance(lvalue.node, Var)
2158+
and isinstance(lvalue.node.type, PartialType)):
2159+
var = lvalue.node
2160+
typ = lvalue.node.type
2161+
if typ.type is None:
2162+
return
2163+
partial_types = self.find_partial_types(var)
2164+
if partial_types is None:
2165+
return
2166+
rvalue_type = self.expr_checker.accept(rvalue)
2167+
rvalue_type = get_proper_type(rvalue_type)
2168+
if isinstance(rvalue_type, Instance):
2169+
if rvalue_type.type == typ.type:
2170+
var.type = rvalue_type
2171+
del partial_types[var]
2172+
elif isinstance(rvalue_type, AnyType):
2173+
var.type = fill_typevars_with_any(typ.type)
2174+
del partial_types[var]
2175+
21442176
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
21452177
rvalue: Expression) -> bool:
21462178
lvalue_node = lvalue.node

mypy/errors.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,14 +598,22 @@ def remove_duplicates(self, errors: List[ErrorTuple]) -> List[ErrorTuple]:
598598
i = 0
599599
while i < len(errors):
600600
dup = False
601+
# Use slightly special formatting for member conflicts reporting.
602+
conflicts_notes = False
603+
j = i - 1
604+
while j >= 0 and errors[j][0] == errors[i][0]:
605+
if errors[j][4].strip() == 'Got:':
606+
conflicts_notes = True
607+
j -= 1
601608
j = i - 1
602609
while (j >= 0 and errors[j][0] == errors[i][0] and
603610
errors[j][1] == errors[i][1]):
604611
if (errors[j][3] == errors[i][3] and
605612
# Allow duplicate notes in overload conflicts reporting.
606-
not (errors[i][3] == 'note' and
607-
errors[i][4].strip() in allowed_duplicates
608-
or errors[i][4].strip().startswith('def ')) and
613+
not ((errors[i][3] == 'note' and
614+
errors[i][4].strip() in allowed_duplicates)
615+
or (errors[i][4].strip().startswith('def ') and
616+
conflicts_notes)) and
609617
errors[j][4] == errors[i][4]): # ignore column
610618
dup = True
611619
break

test-data/unit/check-inference.test

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,34 +1367,29 @@ a = []
13671367
a.append(1)
13681368
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
13691369
[builtins fixtures/list.pyi]
1370-
[out]
13711370

13721371
[case testInferListInitializedToEmptyUsingUpdate]
13731372
a = []
13741373
a.extend([''])
13751374
a.append(0) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str"
13761375
[builtins fixtures/list.pyi]
1377-
[out]
13781376

13791377
[case testInferListInitializedToEmptyAndNotAnnotated]
13801378
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13811379
[builtins fixtures/list.pyi]
1382-
[out]
13831380

13841381
[case testInferListInitializedToEmptyAndReadBeforeAppend]
13851382
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13861383
if a: pass
13871384
a.xyz # E: "List[Any]" has no attribute "xyz"
13881385
a.append('')
13891386
[builtins fixtures/list.pyi]
1390-
[out]
13911387

13921388
[case testInferListInitializedToEmptyAndIncompleteTypeInAppend]
13931389
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
13941390
a.append([])
13951391
a() # E: "List[Any]" not callable
13961392
[builtins fixtures/list.pyi]
1397-
[out]
13981393

13991394
[case testInferListInitializedToEmptyAndMultipleAssignment]
14001395
a, b = [], []
@@ -1403,15 +1398,13 @@ b.append('')
14031398
a() # E: "List[int]" not callable
14041399
b() # E: "List[str]" not callable
14051400
[builtins fixtures/list.pyi]
1406-
[out]
14071401

14081402
[case testInferListInitializedToEmptyInFunction]
14091403
def f() -> None:
14101404
a = []
14111405
a.append(1)
14121406
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14131407
[builtins fixtures/list.pyi]
1414-
[out]
14151408

14161409
[case testInferListInitializedToEmptyAndNotAnnotatedInFunction]
14171410
def f() -> None:
@@ -1422,7 +1415,6 @@ def g() -> None: pass
14221415
a = []
14231416
a.append(1)
14241417
[builtins fixtures/list.pyi]
1425-
[out]
14261418

14271419
[case testInferListInitializedToEmptyAndReadBeforeAppendInFunction]
14281420
def f() -> None:
@@ -1431,15 +1423,13 @@ def f() -> None:
14311423
a.xyz # E: "List[Any]" has no attribute "xyz"
14321424
a.append('')
14331425
[builtins fixtures/list.pyi]
1434-
[out]
14351426

14361427
[case testInferListInitializedToEmptyInClassBody]
14371428
class A:
14381429
a = []
14391430
a.append(1)
14401431
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14411432
[builtins fixtures/list.pyi]
1442-
[out]
14431433

14441434
[case testInferListInitializedToEmptyAndNotAnnotatedInClassBody]
14451435
class A:
@@ -1449,7 +1439,6 @@ class B:
14491439
a = []
14501440
a.append(1)
14511441
[builtins fixtures/list.pyi]
1452-
[out]
14531442

14541443
[case testInferListInitializedToEmptyInMethod]
14551444
class A:
@@ -1458,14 +1447,12 @@ class A:
14581447
a.append(1)
14591448
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
14601449
[builtins fixtures/list.pyi]
1461-
[out]
14621450

14631451
[case testInferListInitializedToEmptyAndNotAnnotatedInMethod]
14641452
class A:
14651453
def f(self) -> None:
14661454
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
14671455
[builtins fixtures/list.pyi]
1468-
[out]
14691456

14701457
[case testInferListInitializedToEmptyInMethodViaAttribute]
14711458
class A:
@@ -1475,7 +1462,6 @@ class A:
14751462
self.a.append(1)
14761463
self.a.append('')
14771464
[builtins fixtures/list.pyi]
1478-
[out]
14791465

14801466
[case testInferListInitializedToEmptyInClassBodyAndOverriden]
14811467
from typing import List
@@ -1490,57 +1476,49 @@ class B(A):
14901476
def x(self) -> List[int]: # E: Signature of "x" incompatible with supertype "A"
14911477
return [123]
14921478
[builtins fixtures/list.pyi]
1493-
[out]
14941479

14951480
[case testInferSetInitializedToEmpty]
14961481
a = set()
14971482
a.add(1)
14981483
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
14991484
[builtins fixtures/set.pyi]
1500-
[out]
15011485

15021486
[case testInferSetInitializedToEmptyUsingDiscard]
15031487
a = set()
15041488
a.discard('')
15051489
a.add(0) # E: Argument 1 to "add" of "set" has incompatible type "int"; expected "str"
15061490
[builtins fixtures/set.pyi]
1507-
[out]
15081491

15091492
[case testInferSetInitializedToEmptyUsingUpdate]
15101493
a = set()
15111494
a.update({0})
15121495
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
15131496
[builtins fixtures/set.pyi]
1514-
[out]
15151497

15161498
[case testInferDictInitializedToEmpty]
15171499
a = {}
15181500
a[1] = ''
15191501
a() # E: "Dict[int, str]" not callable
15201502
[builtins fixtures/dict.pyi]
1521-
[out]
15221503

15231504
[case testInferDictInitializedToEmptyUsingUpdate]
15241505
a = {}
15251506
a.update({'': 42})
15261507
a() # E: "Dict[str, int]" not callable
15271508
[builtins fixtures/dict.pyi]
1528-
[out]
15291509

15301510
[case testInferDictInitializedToEmptyUsingUpdateError]
15311511
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
15321512
a.update([1, 2]) # E: Argument 1 to "update" of "dict" has incompatible type "List[int]"; expected "Mapping[Any, Any]"
15331513
a() # E: "Dict[Any, Any]" not callable
15341514
[builtins fixtures/dict.pyi]
1535-
[out]
15361515

15371516
[case testInferDictInitializedToEmptyAndIncompleteTypeInUpdate]
15381517
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
15391518
a[1] = {}
15401519
b = {} # E: Need type annotation for 'b' (hint: "b: Dict[<type>, <type>] = ...")
15411520
b[{}] = 1
15421521
[builtins fixtures/dict.pyi]
1543-
[out]
15441522

15451523
[case testInferDictInitializedToEmptyAndUpdatedFromMethod]
15461524
map = {}
@@ -1557,20 +1535,42 @@ def add():
15571535
[case testSpecialCaseEmptyListInitialization]
15581536
def f(blocks: Any): # E: Name 'Any' is not defined \
15591537
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
1560-
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
1538+
to_process = []
15611539
to_process = list(blocks)
15621540
[builtins fixtures/list.pyi]
1563-
[out]
15641541

15651542
[case testSpecialCaseEmptyListInitialization2]
15661543
def f(blocks: object):
1567-
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
1544+
to_process = []
15681545
to_process = list(blocks) # E: No overload variant of "list" matches argument type "object" \
15691546
# N: Possible overload variant: \
15701547
# N: def [T] __init__(self, x: Iterable[T]) -> List[T] \
15711548
# N: <1 more non-matching overload not shown>
15721549
[builtins fixtures/list.pyi]
1573-
[out]
1550+
1551+
[case testInferListInitializedToEmptyAndAssigned]
1552+
a = []
1553+
if bool():
1554+
a = [1]
1555+
reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int*]'
1556+
1557+
def f():
1558+
return [1]
1559+
b = []
1560+
if bool():
1561+
b = f()
1562+
reveal_type(b) # N: Revealed type is 'builtins.list[Any]'
1563+
1564+
d = {}
1565+
if bool():
1566+
d = {1: 'x'}
1567+
reveal_type(d) # N: Revealed type is 'builtins.dict[builtins.int*, builtins.str*]'
1568+
1569+
dd = {} # E: Need type annotation for 'dd' (hint: "dd: Dict[<type>, <type>] = ...")
1570+
if bool():
1571+
dd = [1] # E: Incompatible types in assignment (expression has type "List[int]", variable has type "Dict[Any, Any]")
1572+
reveal_type(dd) # N: Revealed type is 'builtins.dict[Any, Any]'
1573+
[builtins fixtures/dict.pyi]
15741574

15751575

15761576
-- Inferring types of variables first initialized to None (partial types)

0 commit comments

Comments
 (0)