Skip to content

Commit b470951

Browse files
committed
Merge branch 'collection-type-infer'
Partially addresses #1055.
2 parents 1d04eb7 + 8af8d14 commit b470951

19 files changed

+360
-60
lines changed

mypy/checker.py

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from mypy.types import (
3030
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
3131
Instance, NoneTyp, UnboundType, ErrorType, TypeTranslator, strip_type,
32-
UnionType, TypeVarType,
32+
UnionType, TypeVarType, PartialType
3333
)
3434
from mypy.sametypes import is_same_type
3535
from mypy.messages import MessageBuilder
@@ -332,7 +332,8 @@ class TypeChecker(NodeVisitor[Type]):
332332
breaking_out = False
333333
# Do weak type checking in this file
334334
weak_opts = set() # type: Set[str]
335-
335+
# Stack of collections of variables with partial types
336+
partial_types = None # type: List[Dict[Var, Context]]
336337
globals = None # type: SymbolTable
337338
locals = None # type: SymbolTable
338339
modules = None # type: Dict[str, MypyFile]
@@ -358,6 +359,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile],
358359
self.dynamic_funcs = []
359360
self.function_stack = []
360361
self.weak_opts = set() # type: Set[str]
362+
self.partial_types = []
361363

362364
def visit_file(self, file_node: MypyFile, path: str) -> None:
363365
"""Type check a mypy file with the given path."""
@@ -367,10 +369,12 @@ def visit_file(self, file_node: MypyFile, path: str) -> None:
367369
self.globals = file_node.names
368370
self.locals = None
369371
self.weak_opts = file_node.weak_opts
372+
self.enter_partial_types()
370373

371374
for d in file_node.defs:
372375
self.accept(d)
373376

377+
self.leave_partial_types()
374378
self.errors.set_ignored_lines(set())
375379

376380
def accept(self, node: Node, type_context: Type = None) -> Type:
@@ -461,6 +465,8 @@ def check_func_item(self, defn: FuncItem,
461465
if fdef:
462466
self.errors.push_function(fdef.name())
463467

468+
self.enter_partial_types()
469+
464470
typ = self.function_type(defn)
465471
if type_override:
466472
typ = type_override
@@ -469,6 +475,8 @@ def check_func_item(self, defn: FuncItem,
469475
else:
470476
raise RuntimeError('Not supported')
471477

478+
self.leave_partial_types()
479+
472480
if fdef:
473481
self.errors.pop_function()
474482

@@ -864,12 +872,14 @@ def visit_class_def(self, defn: ClassDef) -> Type:
864872
"""Type check a class definition."""
865873
typ = defn.info
866874
self.errors.push_type(defn.name)
875+
self.enter_partial_types()
867876
old_binder = self.binder
868877
self.binder = ConditionalTypeBinder()
869878
self.binder.push_frame()
870879
self.accept(defn.defs)
871880
self.binder = old_binder
872881
self.check_multiple_inheritance(typ)
882+
self.leave_partial_types()
873883
self.errors.pop_type()
874884

875885
def check_multiple_inheritance(self, typ: TypeInfo) -> None:
@@ -1237,11 +1247,14 @@ def infer_variable_type(self, name: Var, lvalue: Node,
12371247
elif isinstance(init_type, Void):
12381248
self.check_not_void(init_type, context)
12391249
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
1240-
elif not self.is_valid_inferred_type(init_type):
1241-
# We cannot use the type of the initialization expression for type
1242-
# inference (it's not specific enough).
1243-
self.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
1244-
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
1250+
elif not is_valid_inferred_type(init_type):
1251+
# We cannot use the type of the initialization expression for full type
1252+
# inference (it's not specific enough), but we might be able to give
1253+
# partial type which will be made more specific later. A partial type
1254+
# gets generated in assignment like 'x = []' where item type is not known.
1255+
if not self.infer_partial_type(name, lvalue, init_type):
1256+
self.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
1257+
self.set_inference_error_fallback_type(name, lvalue, init_type, context)
12451258
else:
12461259
# Infer type of the target.
12471260

@@ -1250,6 +1263,21 @@ def infer_variable_type(self, name: Var, lvalue: Node,
12501263

12511264
self.set_inferred_type(name, lvalue, init_type)
12521265

1266+
def infer_partial_type(self, name: Var, lvalue: Node, init_type: Type) -> bool:
1267+
if not isinstance(init_type, Instance):
1268+
return False
1269+
fullname = init_type.type.fullname()
1270+
if ((fullname == 'builtins.list' or fullname == 'builtins.set' or
1271+
fullname == 'builtins.dict')
1272+
and isinstance(init_type.args[0], NoneTyp)
1273+
and (fullname != 'builtins.dict' or isinstance(init_type.args[1], NoneTyp))
1274+
and isinstance(lvalue, NameExpr)):
1275+
partial_type = PartialType(init_type.type, name)
1276+
self.set_inferred_type(name, lvalue, partial_type)
1277+
self.partial_types[-1][name] = lvalue
1278+
return True
1279+
return False
1280+
12531281
def set_inferred_type(self, var: Var, lvalue: Node, type: Type) -> None:
12541282
"""Store inferred variable type.
12551283
@@ -1275,23 +1303,6 @@ def set_inference_error_fallback_type(self, var: Var, lvalue: Node, type: Type,
12751303
if context.get_line() in self.errors.ignored_lines:
12761304
self.set_inferred_type(var, lvalue, AnyType())
12771305

1278-
def is_valid_inferred_type(self, typ: Type) -> bool:
1279-
"""Is an inferred type invalid?
1280-
1281-
Examples include the None type or a type with a None component.
1282-
"""
1283-
if is_same_type(typ, NoneTyp()):
1284-
return False
1285-
elif isinstance(typ, Instance):
1286-
for arg in typ.args:
1287-
if not self.is_valid_inferred_type(arg):
1288-
return False
1289-
elif isinstance(typ, TupleType):
1290-
for item in typ.items:
1291-
if not self.is_valid_inferred_type(item):
1292-
return False
1293-
return True
1294-
12951306
def narrow_type_from_binder(self, expr: Node, known_type: Type) -> Type:
12961307
if expr.literal >= LITERAL_TYPE:
12971308
restriction = self.binder.get(expr)
@@ -1323,6 +1334,7 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
13231334
13241335
The lvalue argument is the base[index] expression.
13251336
"""
1337+
self.try_infer_partial_type_from_indexed_assignment(lvalue, rvalue)
13261338
basetype = self.accept(lvalue.base)
13271339
method_type = self.expr_checker.analyze_external_member_access(
13281340
'__setitem__', basetype, context)
@@ -1331,6 +1343,26 @@ def check_indexed_assignment(self, lvalue: IndexExpr,
13311343
[nodes.ARG_POS, nodes.ARG_POS],
13321344
context)
13331345

1346+
def try_infer_partial_type_from_indexed_assignment(
1347+
self, lvalue: IndexExpr, rvalue: Node) -> None:
1348+
# TODO: Should we share some of this with try_infer_partial_type?
1349+
partial_types = self.partial_types[-1]
1350+
if not partial_types:
1351+
# Fast path leave -- no partial types in the current scope.
1352+
return
1353+
if isinstance(lvalue.base, RefExpr):
1354+
var = lvalue.base.node
1355+
if var in partial_types:
1356+
var = cast(Var, var)
1357+
typename = cast(Instance, var.type).type.fullname()
1358+
if typename == 'builtins.dict':
1359+
# TODO: Don't infer things twice.
1360+
key_type = self.accept(lvalue.index)
1361+
value_type = self.accept(rvalue)
1362+
if is_valid_inferred_type(key_type) and is_valid_inferred_type(value_type):
1363+
var.type = self.named_generic_type('builtins.dict', [key_type, value_type])
1364+
del partial_types[var]
1365+
13341366
def visit_expression_stmt(self, s: ExpressionStmt) -> Type:
13351367
self.accept(s.expr)
13361368

@@ -2032,6 +2064,21 @@ def enter(self) -> None:
20322064
def leave(self) -> None:
20332065
self.locals = None
20342066

2067+
def enter_partial_types(self) -> None:
2068+
"""Push a new scope for collecting partial types."""
2069+
self.partial_types.append({})
2070+
2071+
def leave_partial_types(self) -> None:
2072+
"""Pop partial type scope.
2073+
2074+
Also report errors for variables which still have partial
2075+
types, i.e. we couldn't infer a complete type.
2076+
"""
2077+
partial_types = self.partial_types.pop()
2078+
for var, context in partial_types.items():
2079+
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
2080+
var.type = AnyType()
2081+
20352082
def is_within_function(self) -> bool:
20362083
"""Are we currently type checking within a function?
20372084
@@ -2289,3 +2336,21 @@ def infer_operator_assignment_method(type: Type, operator: str) -> str:
22892336
if type.type.has_readable_member(inplace):
22902337
method = inplace
22912338
return method
2339+
2340+
2341+
def is_valid_inferred_type(typ: Type) -> bool:
2342+
"""Is an inferred type valid?
2343+
2344+
Examples of invalid types include the None type or a type with a None component.
2345+
"""
2346+
if is_same_type(typ, NoneTyp()):
2347+
return False
2348+
elif isinstance(typ, Instance):
2349+
for arg in typ.args:
2350+
if not is_valid_inferred_type(arg):
2351+
return False
2352+
elif isinstance(typ, TupleType):
2353+
for item in typ.items:
2354+
if not is_valid_inferred_type(item):
2355+
return False
2356+
return True

mypy/checkexpr.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
7-
TupleType, Instance, TypeVarType, TypeTranslator, ErasedType, FunctionLike, UnionType
7+
TupleType, Instance, TypeVarType, TypeTranslator, ErasedType, FunctionLike, UnionType,
8+
PartialType
89
)
910
from mypy.nodes import (
1011
NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr,
@@ -14,7 +15,7 @@
1415
ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator,
1516
ConditionalExpr, ComparisonExpr, TempNode, SetComprehension,
1617
DictionaryComprehension, ComplexExpr, EllipsisExpr, LITERAL_TYPE,
17-
TypeAliasExpr, YieldExpr, BackquoteExpr
18+
TypeAliasExpr, YieldExpr, BackquoteExpr, ARG_POS
1819
)
1920
from mypy.errors import Errors
2021
from mypy.nodes import function_type
@@ -72,6 +73,11 @@ def analyze_ref_expr(self, e: RefExpr) -> Type:
7273
if isinstance(node, Var):
7374
# Variable reference.
7475
result = self.analyze_var_ref(node, e)
76+
if isinstance(result, PartialType):
77+
partial_types = self.chk.partial_types[-1]
78+
context = partial_types[node]
79+
self.msg.fail(messages.NEED_ANNOTATION_FOR_VAR, context)
80+
result = AnyType()
7581
elif isinstance(node, FuncDef):
7682
# Reference to a global function.
7783
result = function_type(node, self.named_type('builtins.function'))
@@ -110,13 +116,35 @@ def visit_call_expr(self, e: CallExpr) -> Type:
110116
if e.analyzed:
111117
# It's really a special form that only looks like a call.
112118
return self.accept(e.analyzed)
119+
self.try_infer_partial_type(e)
113120
self.accept(e.callee)
114121
# Access callee type directly, since accept may return the Any type
115122
# even if the type is known (in a dynamically typed function). This
116123
# way we get a more precise callee in dynamically typed functions.
117124
callee_type = self.chk.type_map[e.callee]
118125
return self.check_call_expr_with_callee_type(callee_type, e)
119126

127+
def try_infer_partial_type(self, e: CallExpr) -> None:
128+
partial_types = self.chk.partial_types[-1]
129+
if not partial_types:
130+
# Fast path leave -- no partial types in the current scope.
131+
return
132+
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
133+
var = e.callee.expr.node
134+
if var in partial_types:
135+
var = cast(Var, var)
136+
typename = cast(Instance, var.type).type.fullname()
137+
methodname = e.callee.name
138+
if (((typename == 'builtins.list' and methodname == 'append') or
139+
(typename == 'builtins.set' and methodname == 'add'))
140+
and e.arg_kinds == [ARG_POS]):
141+
# We can infer a full type for a partial List type.
142+
# TODO: Don't infer argument expression twice.
143+
item_type = self.accept(e.args[0])
144+
if mypy.checker.is_valid_inferred_type(item_type):
145+
var.type = self.chk.named_generic_type(typename, [item_type])
146+
del partial_types[var]
147+
120148
def check_call_expr_with_callee_type(self, callee_type: Type,
121149
e: CallExpr) -> Type:
122150
"""Type check call expression.

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef,
7-
Overloaded, TypeVarType, TypeTranslator, UnionType
7+
Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType
88
)
99
from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context
1010
from mypy.nodes import ARG_POS, function_type, Decorator, OverloadedFuncDef

mypy/constraints.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType,
7-
Instance, TupleType, UnionType, Overloaded, ErasedType, is_named_instance
7+
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, is_named_instance
88
)
99
from mypy.expandtype import expand_caller_var_args
1010
from mypy.maptype import map_instance_to_supertype
@@ -151,6 +151,12 @@ def visit_none_type(self, template: NoneTyp) -> List[Constraint]:
151151
def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
152152
return []
153153

154+
# Errors
155+
156+
def visit_partial_type(self, template: PartialType) -> List[Constraint]:
157+
# We can't do anything useful with a partial type here.
158+
assert False, "Internal error"
159+
154160
# Non-trivial leaf type
155161

156162
def visit_type_var(self, template: TypeVarType) -> List[Constraint]:

mypy/erasetype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from mypy.types import (
44
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp,
55
Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType,
6-
TypeTranslator, TypeList
6+
TypeTranslator, TypeList, PartialType
77
)
88

99

@@ -46,6 +46,10 @@ def visit_erased_type(self, t: ErasedType) -> Type:
4646
# Should not get here.
4747
raise RuntimeError()
4848

49+
def visit_partial_type(self, t: PartialType) -> Type:
50+
# Should not get here.
51+
raise RuntimeError()
52+
4953
def visit_instance(self, t: Instance) -> Type:
5054
return Instance(t.type, [AnyType()] * len(t.args), t.line)
5155

mypy/expandtype.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from mypy.types import (
44
Type, Instance, CallableType, TypeVisitor, UnboundType, ErrorType, AnyType,
5-
Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList
5+
Void, NoneTyp, TypeVarType, Overloaded, TupleType, UnionType, ErasedType, TypeList,
6+
PartialType
67
)
78

89

@@ -92,6 +93,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
9293
def visit_union_type(self, t: UnionType) -> Type:
9394
return UnionType(self.expand_types(t.items), t.line)
9495

96+
def visit_partial_type(self, t: PartialType) -> Type:
97+
return t
98+
9599
def expand_types(self, types: List[Type]) -> List[Type]:
96100
a = [] # type: List[Type]
97101
for t in types:

mypy/join.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.types import (
66
Type, AnyType, NoneTyp, Void, TypeVisitor, Instance, UnboundType,
77
ErrorType, TypeVarType, CallableType, TupleType, ErasedType, TypeList,
8-
UnionType, FunctionLike, Overloaded
8+
UnionType, FunctionLike, Overloaded, PartialType
99
)
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars
@@ -196,6 +196,11 @@ def visit_tuple_type(self, t: TupleType) -> Type:
196196
else:
197197
return self.default(self.s)
198198

199+
def visit_partial_type(self, t: PartialType) -> Type:
200+
# We only have partial information so we can't decide the join result. We should
201+
# never get here.
202+
assert False, "Internal error"
203+
199204
def join(self, s: Type, t: Type) -> Type:
200205
return join_types(s, t)
201206

0 commit comments

Comments
 (0)