Skip to content

Commit bcf89b1

Browse files
authored
Refactor plugin system and special case TypedDict get and int.__pow__ (#3501)
Implement a general-purpose way of extending type inference of methods. Also special case TypedDict get and `int.__pow__`. Implement a new plugin system that can handle both module-level functions and methods. This an alternative to #2620 by @rowillia. I borrowed some test cases from that PR. This PR has a few major differences: * Use the plugin system instead of full special casing. * Don't support `d.get('x', {})` as it's not type safe. Once we have #2632 we can add support for this idiom safely. * Code like `f = foo.get` loses the special casing for get. Fixes #2612. Work towards #1240.
1 parent e86c0dc commit bcf89b1

11 files changed

+494
-115
lines changed

mypy/build.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from mypy.stats import dump_type_stats
4343
from mypy.types import Type
4444
from mypy.version import __version__
45+
from mypy.plugin import DefaultPlugin
4546

4647

4748
# We need to know the location of this file to load data, but
@@ -1505,8 +1506,9 @@ def type_check_first_pass(self) -> None:
15051506
if self.options.semantic_analysis_only:
15061507
return
15071508
with self.wrap_context():
1509+
plugin = DefaultPlugin(self.options.python_version)
15081510
self.type_checker = TypeChecker(manager.errors, manager.modules, self.options,
1509-
self.tree, self.xpath)
1511+
self.tree, self.xpath, plugin)
15101512
self.type_checker.check_first_pass()
15111513

15121514
def type_check_second_pass(self) -> bool:

mypy/checker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
from mypy.binder import ConditionalTypeBinder, get_declaration
5757
from mypy.meet import is_overlapping_types
5858
from mypy.options import Options
59+
from mypy.plugin import Plugin
5960

6061
from mypy import experiments
6162

@@ -127,8 +128,12 @@ class TypeChecker(NodeVisitor[None]):
127128
# directly or indirectly.
128129
module_refs = None # type: Set[str]
129130

131+
# Plugin that provides special type checking rules for specific library
132+
# functions such as open(), etc.
133+
plugin = None # type: Plugin
134+
130135
def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Options,
131-
tree: MypyFile, path: str) -> None:
136+
tree: MypyFile, path: str, plugin: Plugin) -> None:
132137
"""Construct a type checker.
133138
134139
Use errors to report type check errors.
@@ -139,7 +144,8 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option
139144
self.tree = tree
140145
self.path = path
141146
self.msg = MessageBuilder(errors, modules)
142-
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg)
147+
self.plugin = plugin
148+
self.expr_checker = mypy.checkexpr.ExpressionChecker(self, self.msg, self.plugin)
143149
self.scope = Scope(tree)
144150
self.binder = ConditionalTypeBinder()
145151
self.globals = tree.names

mypy/checkexpr.py

Lines changed: 131 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from mypy.util import split_module_names
4545
from mypy.typevars import fill_typevars
4646
from mypy.visitor import ExpressionVisitor
47-
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback
47+
from mypy.plugin import Plugin, PluginContext, MethodSignatureHook
4848
from mypy.typeanal import make_optional_type
4949

5050
from mypy import experiments
@@ -105,17 +105,18 @@ class ExpressionChecker(ExpressionVisitor[Type]):
105105
type_context = None # type: List[Optional[Type]]
106106

107107
strfrm_checker = None # type: StringFormatterChecker
108-
function_plugins = None # type: Dict[str, PluginCallback]
108+
plugin = None # type: Plugin
109109

110110
def __init__(self,
111111
chk: 'mypy.checker.TypeChecker',
112-
msg: MessageBuilder) -> None:
112+
msg: MessageBuilder,
113+
plugin: Plugin) -> None:
113114
"""Construct an expression type checker."""
114115
self.chk = chk
115116
self.msg = msg
117+
self.plugin = plugin
116118
self.type_context = [None]
117119
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
118-
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)
119120

120121
def visit_name_expr(self, e: NameExpr) -> Type:
121122
"""Type check a name expression.
@@ -208,11 +209,33 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
208209
isinstance(callee_type, CallableType)
209210
and callee_type.implicit):
210211
return self.msg.untyped_function_call(callee_type, e)
212+
# Figure out the full name of the callee for plugin loopup.
213+
object_type = None
211214
if not isinstance(e.callee, RefExpr):
212215
fullname = None
213216
else:
214217
fullname = e.callee.fullname
215-
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
218+
if (fullname is None
219+
and isinstance(e.callee, MemberExpr)
220+
and isinstance(callee_type, FunctionLike)):
221+
# For method calls we include the defining class for the method
222+
# in the full name (example: 'typing.Mapping.get').
223+
callee_expr_type = self.chk.type_map.get(e.callee.expr)
224+
info = None
225+
# TODO: Support fallbacks of other kinds of types as well?
226+
if isinstance(callee_expr_type, Instance):
227+
info = callee_expr_type.type
228+
elif isinstance(callee_expr_type, TypedDictType):
229+
info = callee_expr_type.fallback.type.get_containing_type_info(e.callee.name)
230+
if info:
231+
fullname = '{}.{}'.format(info.fullname(), e.callee.name)
232+
object_type = callee_expr_type
233+
# Apply plugin signature hook that may generate a better signature.
234+
signature_hook = self.plugin.get_method_signature_hook(fullname)
235+
if signature_hook:
236+
callee_type = self.apply_method_signature_hook(
237+
e, callee_type, object_type, signature_hook)
238+
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type)
216239
if isinstance(ret_type, UninhabitedType):
217240
self.chk.binder.unreachable()
218241
if not allow_none_return and isinstance(ret_type, NoneTyp):
@@ -351,8 +374,10 @@ def apply_function_plugin(self,
351374
formal_to_actual: List[List[int]],
352375
args: List[Expression],
353376
num_formals: int,
354-
fullname: Optional[str]) -> Type:
355-
"""Use special case logic to infer the return type for of a particular named function.
377+
fullname: Optional[str],
378+
object_type: Optional[Type],
379+
context: Context) -> Type:
380+
"""Use special case logic to infer the return type of a specific named function/method.
356381
357382
Return the inferred return type.
358383
"""
@@ -362,41 +387,90 @@ def apply_function_plugin(self,
362387
for actual in actuals:
363388
formal_arg_types[formal].append(arg_types[actual])
364389
formal_arg_exprs[formal].append(args[actual])
365-
return self.function_plugins[fullname](
366-
formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type)
367-
368-
def check_call_expr_with_callee_type(self, callee_type: Type,
369-
e: CallExpr, callable_name: Optional[str]) -> Type:
390+
if object_type is None:
391+
# Apply function plugin
392+
callback = self.plugin.get_function_hook(fullname)
393+
assert callback is not None # Assume that caller ensures this
394+
return callback(formal_arg_types, formal_arg_exprs, inferred_ret_type,
395+
self.chk.named_generic_type)
396+
else:
397+
# Apply method plugin
398+
method_callback = self.plugin.get_method_hook(fullname)
399+
assert method_callback is not None # Assume that caller ensures this
400+
return method_callback(object_type, formal_arg_types, formal_arg_exprs,
401+
inferred_ret_type, self.create_plugin_context(context))
402+
403+
def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_type: Type,
404+
signature_hook: MethodSignatureHook) -> FunctionLike:
405+
"""Apply a plugin hook that may infer a more precise signature for a method."""
406+
if isinstance(callee, CallableType):
407+
arg_kinds = e.arg_kinds
408+
arg_names = e.arg_names
409+
args = e.args
410+
num_formals = len(callee.arg_kinds)
411+
formal_to_actual = map_actuals_to_formals(
412+
arg_kinds, arg_names,
413+
callee.arg_kinds, callee.arg_names,
414+
lambda i: self.accept(args[i]))
415+
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
416+
for formal, actuals in enumerate(formal_to_actual):
417+
for actual in actuals:
418+
formal_arg_exprs[formal].append(args[actual])
419+
return signature_hook(object_type, formal_arg_exprs, callee,
420+
self.chk.named_generic_type)
421+
else:
422+
assert isinstance(callee, Overloaded)
423+
items = []
424+
for item in callee.items():
425+
adjusted = self.apply_method_signature_hook(e, item, object_type, signature_hook)
426+
assert isinstance(adjusted, CallableType)
427+
items.append(adjusted)
428+
return Overloaded(items)
429+
430+
def create_plugin_context(self, context: Context) -> PluginContext:
431+
return PluginContext(self.chk.named_generic_type, self.msg, context)
432+
433+
def check_call_expr_with_callee_type(self,
434+
callee_type: Type,
435+
e: CallExpr,
436+
callable_name: Optional[str],
437+
object_type: Optional[Type]) -> Type:
370438
"""Type check call expression.
371439
372440
The given callee type overrides the type of the callee
373441
expression.
374442
"""
375443
return self.check_call(callee_type, e.args, e.arg_kinds, e,
376444
e.arg_names, callable_node=e.callee,
377-
callable_name=callable_name)[0]
445+
callable_name=callable_name,
446+
object_type=object_type)[0]
378447

379448
def check_call(self, callee: Type, args: List[Expression],
380449
arg_kinds: List[int], context: Context,
381450
arg_names: List[str] = None,
382451
callable_node: Expression = None,
383452
arg_messages: MessageBuilder = None,
384-
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
453+
callable_name: Optional[str] = None,
454+
object_type: Optional[Type] = None) -> Tuple[Type, Type]:
385455
"""Type check a call.
386456
387457
Also infer type arguments if the callee is a generic function.
388458
389459
Return (result type, inferred callee type).
390460
391461
Arguments:
392-
callee: type of the called value
393-
args: actual argument expressions
394-
arg_kinds: contains nodes.ARG_* constant for each argument in args
395-
describing whether the argument is positional, *arg, etc.
396-
arg_names: names of arguments (optional)
397-
callable_node: associate the inferred callable type to this node,
398-
if specified
399-
arg_messages: TODO
462+
callee: type of the called value
463+
args: actual argument expressions
464+
arg_kinds: contains nodes.ARG_* constant for each argument in args
465+
describing whether the argument is positional, *arg, etc.
466+
arg_names: names of arguments (optional)
467+
callable_node: associate the inferred callable type to this node,
468+
if specified
469+
arg_messages: TODO
470+
callable_name: Fully-qualified name of the function/method to call,
471+
or None if unavaiable (examples: 'builtins.open', 'typing.Mapping.get')
472+
object_type: If callable_name refers to a method, the type of the object
473+
on which the method is being called
400474
"""
401475
arg_messages = arg_messages or self.msg
402476
if isinstance(callee, CallableType):
@@ -443,10 +517,12 @@ def check_call(self, callee: Type, args: List[Expression],
443517
if callable_node:
444518
# Store the inferred callable type.
445519
self.chk.store_type(callable_node, callee)
446-
if callable_name in self.function_plugins:
520+
521+
if ((object_type is None and self.plugin.get_function_hook(callable_name))
522+
or (object_type is not None and self.plugin.get_method_hook(callable_name))):
447523
ret_type = self.apply_function_plugin(
448524
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
449-
args, len(callee.arg_types), callable_name)
525+
args, len(callee.arg_types), callable_name, object_type, context)
450526
callee = callee.copy_modified(ret_type=ret_type)
451527
return callee.ret_type, callee
452528
elif isinstance(callee, Overloaded):
@@ -461,7 +537,9 @@ def check_call(self, callee: Type, args: List[Expression],
461537
callee, context,
462538
messages=arg_messages)
463539
return self.check_call(target, args, arg_kinds, context, arg_names,
464-
arg_messages=arg_messages)
540+
arg_messages=arg_messages,
541+
callable_name=callable_name,
542+
object_type=object_type)
465543
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
466544
self.infer_arg_types_in_context(None, args)
467545
return AnyType(), AnyType()
@@ -1295,8 +1373,16 @@ def check_op_local(self, method: str, base_type: Type, arg: Expression,
12951373
method_type = analyze_member_access(method, base_type, context, False, False, True,
12961374
self.named_type, self.not_ready_callback, local_errors,
12971375
original_type=base_type, chk=self.chk)
1376+
callable_name = None
1377+
object_type = None
1378+
if isinstance(base_type, Instance):
1379+
# TODO: Find out in which class the method was defined originally?
1380+
# TODO: Support non-Instance types.
1381+
callable_name = '{}.{}'.format(base_type.type.fullname(), method)
1382+
object_type = base_type
12981383
return self.check_call(method_type, [arg], [nodes.ARG_POS],
1299-
context, arg_messages=local_errors)
1384+
context, arg_messages=local_errors,
1385+
callable_name=callable_name, object_type=object_type)
13001386

13011387
def check_op(self, method: str, base_type: Type, arg: Expression,
13021388
context: Context,
@@ -1769,13 +1855,14 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
17691855
# an error, but returns the TypedDict type that matches the literal it found
17701856
# that would cause a second error when that TypedDict type is returned upstream
17711857
# to avoid the second error, we always return TypedDict type that was requested
1772-
if isinstance(self.type_context[-1], TypedDictType):
1858+
typeddict_context = self.find_typeddict_context(self.type_context[-1])
1859+
if typeddict_context:
17731860
self.check_typeddict_call_with_dict(
1774-
callee=self.type_context[-1],
1861+
callee=typeddict_context,
17751862
kwargs=e,
17761863
context=e
17771864
)
1778-
return self.type_context[-1].copy_modified()
1865+
return typeddict_context.copy_modified()
17791866

17801867
# Collect function arguments, watching out for **expr.
17811868
args = [] # type: List[Expression] # Regular "key: value"
@@ -1826,6 +1913,21 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
18261913
self.check_call(method, [arg], [nodes.ARG_POS], arg)
18271914
return rv
18281915

1916+
def find_typeddict_context(self, context: Type) -> Optional[TypedDictType]:
1917+
if isinstance(context, TypedDictType):
1918+
return context
1919+
elif isinstance(context, UnionType):
1920+
items = []
1921+
for item in context.items:
1922+
item_context = self.find_typeddict_context(item)
1923+
if item_context:
1924+
items.append(item_context)
1925+
if len(items) == 1:
1926+
# Only one union item is TypedDict, so use the context as it's unambiguous.
1927+
return items[0]
1928+
# No TypedDict type in context.
1929+
return None
1930+
18291931
def visit_lambda_expr(self, e: LambdaExpr) -> Type:
18301932
"""Type check lambda expression."""
18311933
inferred_type, type_override = self.infer_lambda_type_using_context(e)

0 commit comments

Comments
 (0)