Skip to content

Commit 53879ef

Browse files
JukkaLilevkivskyi
authored andcommitted
Add type checking plugin support for functions (#3299)
* Add type checking plugin support for functions The plugins allow implementing special-case logic for inferring the return type of certain functions with tricky signatures such as `open` in Python 3. Include plugins for `open` and `contextlib.contextmanager`. Some design considerations: - The plugins have direct access to mypy internals. The idea is that most plugins will be included with mypy so mypy maintainers can update the plugins as needed. - User-maintained plugins are currently not supported but could be added in the future. However, the intention is to not have a stable plugin API, at least initially. User-maintained plugins would have to track mypy internal API changes. Later on, we may decide to provide a more stable API if there seems to be a significant need. The preferred way would still be to keep plugins in the mypy repo. * Add test case for additional special cases * Fix handling of arguments other than simple positional ones Also add comments and some defensive checks.
1 parent a494197 commit 53879ef

File tree

4 files changed

+176
-6
lines changed

4 files changed

+176
-6
lines changed

mypy/checker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2223,8 +2223,12 @@ def visit_decorator(self, e: Decorator) -> None:
22232223
continue
22242224
dec = self.expr_checker.accept(d)
22252225
temp = self.temp_node(sig)
2226+
fullname = None
2227+
if isinstance(d, RefExpr):
2228+
fullname = d.fullname
22262229
sig, t2 = self.expr_checker.check_call(dec, [temp],
2227-
[nodes.ARG_POS], e)
2230+
[nodes.ARG_POS], e,
2231+
callable_name=fullname)
22282232
sig = cast(FunctionLike, sig)
22292233
sig = set_callable_name(sig, e.func)
22302234
e.var.type = sig

mypy/checkexpr.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from mypy.util import split_module_names
4545
from mypy.typevars import fill_typevars
4646
from mypy.visitor import ExpressionVisitor
47+
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback
4748

4849
from mypy import experiments
4950

@@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
103104
type_context = None # type: List[Optional[Type]]
104105

105106
strfrm_checker = None # type: StringFormatterChecker
107+
function_plugins = None # type: Dict[str, PluginCallback]
106108

107109
def __init__(self,
108110
chk: 'mypy.checker.TypeChecker',
@@ -112,6 +114,7 @@ def __init__(self,
112114
self.msg = msg
113115
self.type_context = [None]
114116
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
117+
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)
115118

116119
def visit_name_expr(self, e: NameExpr) -> Type:
117120
"""Type check a name expression.
@@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
198201
isinstance(callee_type, CallableType)
199202
and callee_type.implicit):
200203
return self.msg.untyped_function_call(callee_type, e)
201-
ret_type = self.check_call_expr_with_callee_type(callee_type, e)
204+
if not isinstance(e.callee, RefExpr):
205+
fullname = None
206+
else:
207+
fullname = e.callee.fullname
208+
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
202209
if isinstance(ret_type, UninhabitedType):
203210
self.chk.binder.unreachable()
204211
if not allow_none_return and isinstance(ret_type, NoneTyp):
@@ -330,21 +337,44 @@ def try_infer_partial_type(self, e: CallExpr) -> None:
330337
list(full_item_types))
331338
del partial_types[var]
332339

340+
def apply_function_plugin(self,
341+
arg_types: List[Type],
342+
inferred_ret_type: Type,
343+
arg_kinds: List[int],
344+
formal_to_actual: List[List[int]],
345+
args: List[Expression],
346+
num_formals: int,
347+
fullname: Optional[str]) -> Type:
348+
"""Use special case logic to infer the return type for of a particular named function.
349+
350+
Return the inferred return type.
351+
"""
352+
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
353+
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
354+
for formal, actuals in enumerate(formal_to_actual):
355+
for actual in actuals:
356+
formal_arg_types[formal].append(arg_types[actual])
357+
formal_arg_exprs[formal].append(args[actual])
358+
return self.function_plugins[fullname](
359+
formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type)
360+
333361
def check_call_expr_with_callee_type(self, callee_type: Type,
334-
e: CallExpr) -> Type:
362+
e: CallExpr, callable_name: Optional[str]) -> Type:
335363
"""Type check call expression.
336364
337365
The given callee type overrides the type of the callee
338366
expression.
339367
"""
340368
return self.check_call(callee_type, e.args, e.arg_kinds, e,
341-
e.arg_names, callable_node=e.callee)[0]
369+
e.arg_names, callable_node=e.callee,
370+
callable_name=callable_name)[0]
342371

343372
def check_call(self, callee: Type, args: List[Expression],
344373
arg_kinds: List[int], context: Context,
345374
arg_names: List[str] = None,
346375
callable_node: Expression = None,
347-
arg_messages: MessageBuilder = None) -> Tuple[Type, Type]:
376+
arg_messages: MessageBuilder = None,
377+
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
348378
"""Type check a call.
349379
350380
Also infer type arguments if the callee is a generic function.
@@ -406,6 +436,11 @@ def check_call(self, callee: Type, args: List[Expression],
406436
if callable_node:
407437
# Store the inferred callable type.
408438
self.chk.store_type(callable_node, callee)
439+
if callable_name in self.function_plugins:
440+
ret_type = self.apply_function_plugin(
441+
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
442+
args, len(callee.arg_types), callable_name)
443+
callee = callee.copy_modified(ret_type=ret_type)
409444
return callee.ret_type, callee
410445
elif isinstance(callee, Overloaded):
411446
# Type check arguments in empty context. They will be checked again

mypy/funcplugins.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Plugins that implement special type checking rules for individual functions.
2+
3+
The plugins infer better types for tricky functions such as "open".
4+
"""
5+
6+
from typing import Tuple, Dict, Callable, List
7+
8+
from mypy.nodes import Expression, StrExpr
9+
from mypy.types import Type, Instance, CallableType
10+
11+
12+
# A callback that infers the return type of a function with a special signature.
13+
#
14+
# A no-op callback would just return the inferred return type, but a useful callback
15+
# at least sometimes can infer a more precise type.
16+
PluginCallback = Callable[
17+
[
18+
List[List[Type]], # List of types caller provides for each formal argument
19+
List[List[Expression]], # Actual argument expressions for each formal argument
20+
Type, # Return type for call inferred using the regular signature
21+
Callable[[str, List[Type]], Type] # Callable for constructing a named instance type
22+
],
23+
Type # Return type inferred by the callback
24+
]
25+
26+
27+
def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]:
28+
"""Return all available function plugins for a given Python version."""
29+
if python_version[0] == 3:
30+
return {
31+
'builtins.open': open_callback,
32+
'contextlib.contextmanager': contextmanager_callback,
33+
}
34+
else:
35+
return {
36+
'contextlib.contextmanager': contextmanager_callback,
37+
}
38+
39+
40+
def open_callback(
41+
arg_types: List[List[Type]],
42+
args: List[List[Expression]],
43+
inferred_return_type: Type,
44+
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
45+
"""Infer a better return type for 'open'.
46+
47+
Infer IO[str] or IO[bytes] as the return value if the mode argument is not
48+
given or is a literal.
49+
"""
50+
mode = None
51+
if not arg_types or len(arg_types[1]) != 1:
52+
mode = 'r'
53+
elif isinstance(args[1][0], StrExpr):
54+
mode = args[1][0].value
55+
if mode is not None:
56+
assert isinstance(inferred_return_type, Instance)
57+
if 'b' in mode:
58+
arg = named_generic_type('builtins.bytes', [])
59+
else:
60+
arg = named_generic_type('builtins.str', [])
61+
return Instance(inferred_return_type.type, [arg])
62+
return inferred_return_type
63+
64+
65+
def contextmanager_callback(
66+
arg_types: List[List[Type]],
67+
args: List[List[Expression]],
68+
inferred_return_type: Type,
69+
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
70+
"""Infer a better return type for 'contextlib.contextmanager'."""
71+
# Be defensive, just in case.
72+
if arg_types and len(arg_types[0]) == 1:
73+
arg_type = arg_types[0][0]
74+
if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType):
75+
# The stub signature doesn't preserve information about arguments so
76+
# add them back here.
77+
return inferred_return_type.copy_modified(
78+
arg_types=arg_type.arg_types,
79+
arg_kinds=arg_type.arg_kinds,
80+
arg_names=arg_type.arg_names)
81+
return inferred_return_type

test-data/unit/pythoneval.test

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,33 @@ f.write('x')
399399
f.write(b'x')
400400
f.foobar()
401401
[out]
402-
_program.py:4: error: IO[Any] has no attribute "foobar"
402+
_program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str"
403+
_program.py:4: error: IO[str] has no attribute "foobar"
404+
405+
[case testOpenReturnTypeInference]
406+
reveal_type(open('x'))
407+
reveal_type(open('x', 'r'))
408+
reveal_type(open('x', 'rb'))
409+
mode = 'rb'
410+
reveal_type(open('x', mode))
411+
[out]
412+
_program.py:1: error: Revealed type is 'typing.IO[builtins.str]'
413+
_program.py:2: error: Revealed type is 'typing.IO[builtins.str]'
414+
_program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
415+
_program.py:5: error: Revealed type is 'typing.IO[Any]'
416+
417+
[case testOpenReturnTypeInferenceSpecialCases]
418+
reveal_type(open())
419+
reveal_type(open(mode='rb', file='x'))
420+
reveal_type(open(file='x', mode='rb'))
421+
mode = 'rb'
422+
reveal_type(open(mode=mode, file='r'))
423+
[out]
424+
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Revealed type is 'typing.IO[builtins.str]'
425+
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Too few arguments for "open"
426+
_testOpenReturnTypeInferenceSpecialCases.py:2: error: Revealed type is 'typing.IO[builtins.bytes]'
427+
_testOpenReturnTypeInferenceSpecialCases.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
428+
_testOpenReturnTypeInferenceSpecialCases.py:5: error: Revealed type is 'typing.IO[Any]'
403429

404430
[case testGenericPatterns]
405431
from typing import Pattern
@@ -1286,3 +1312,27 @@ a[1] = 2, 'y'
12861312
a[:] = [('z', 3)]
12871313
[out]
12881314
_program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]")
1315+
1316+
[case testContextManager]
1317+
import contextlib
1318+
from contextlib import contextmanager
1319+
from typing import Iterator
1320+
1321+
@contextmanager
1322+
def f(x: int) -> Iterator[str]:
1323+
yield 'foo'
1324+
1325+
@contextlib.contextmanager
1326+
def g(*x: str) -> Iterator[int]:
1327+
yield 1
1328+
1329+
reveal_type(f)
1330+
reveal_type(g)
1331+
1332+
with f('') as s:
1333+
reveal_type(s)
1334+
[out]
1335+
_program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]'
1336+
_program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]'
1337+
_program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int"
1338+
_program.py:17: error: Revealed type is 'builtins.str*'

0 commit comments

Comments
 (0)