From 2890aa3d3e12eb9249faaa59507cf5c4f90b3e69 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 2 May 2017 13:47:11 +0100 Subject: [PATCH 1/3] Add type checking plugin support for functions The plugins allow implementing special-case logic for inferring the return type of certain functions with tricky signatures such as `open` in Python 3. Include plugins for `open` and `contextlib.contextmanager`. Some design considerations: - The plugins have direct access to mypy internals. The idea is that most plugins will be included with mypy so mypy maintainers can update the plugins as needed. - User-maintained plugins are currently not supported but could be added in the future. However, the intention is to not have a stable plugin API, at least initially. User-maintained plugins would have to track mypy internal API changes. Later on, we may decide to provide a more stable API if there seems to be a significant need. The preferred way would still be to keep plugins in the mypy repo. --- mypy/checker.py | 6 ++- mypy/checkexpr.py | 41 ++++++++++++++++++-- mypy/funcplugins.py | 71 ++++++++++++++++++++++++++++++++++ test-data/unit/pythoneval.test | 39 ++++++++++++++++++- 4 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 mypy/funcplugins.py diff --git a/mypy/checker.py b/mypy/checker.py index 5cbb884d2898..c3326279d48c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2219,8 +2219,12 @@ def visit_decorator(self, e: Decorator) -> None: continue dec = self.expr_checker.accept(d) temp = self.temp_node(sig) + fullname = None + if isinstance(d, RefExpr): + fullname = d.fullname sig, t2 = self.expr_checker.check_call(dec, [temp], - [nodes.ARG_POS], e) + [nodes.ARG_POS], e, + callable_name=fullname) sig = cast(FunctionLike, sig) sig = set_callable_name(sig, e.func) e.var.type = sig diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b19a0a1bcfd2..ab727b5a9a87 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -44,6 +44,7 @@ from mypy.util import split_module_names from mypy.typevars import fill_typevars from mypy.visitor import ExpressionVisitor +from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback from mypy import experiments @@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]): type_context = None # type: List[Optional[Type]] strfrm_checker = None # type: StringFormatterChecker + function_plugins = None # type: Dict[str, PluginCallback] def __init__(self, chk: 'mypy.checker.TypeChecker', @@ -112,6 +114,7 @@ def __init__(self, self.msg = msg self.type_context = [None] self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg) + self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version) def visit_name_expr(self, e: NameExpr) -> Type: """Type check a name expression. @@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: isinstance(callee_type, CallableType) and callee_type.implicit): return self.msg.untyped_function_call(callee_type, e) - ret_type = self.check_call_expr_with_callee_type(callee_type, e) + if not isinstance(e.callee, RefExpr): + fullname = None + else: + fullname = e.callee.fullname + ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname) if isinstance(ret_type, UninhabitedType): self.chk.binder.unreachable() if not allow_none_return and isinstance(ret_type, NoneTyp): @@ -330,21 +337,42 @@ def try_infer_partial_type(self, e: CallExpr) -> None: list(full_item_types)) del partial_types[var] + def apply_function_plugin(self, + arg_types: List[Type], + inferred_ret_type: Type, + arg_kinds: List[int], + formal_to_actual: List[List[int]], + args: List[Expression], + num_formals: int, + fullname: Optional[str]) -> Type: + """Use special case logic to infer the return type for of a particular named function. + + Return the inferred return type. + """ + formal_arg_types = [None] * num_formals # type: List[Optional[Type]] + for formal, actuals in enumerate(formal_to_actual): + for actual in actuals: + formal_arg_types[formal] = arg_types[actual] + return self.function_plugins[fullname]( + formal_arg_types, inferred_ret_type, args, self.chk.named_generic_type) + def check_call_expr_with_callee_type(self, callee_type: Type, - e: CallExpr) -> Type: + e: CallExpr, callable_name: Optional[str]) -> Type: """Type check call expression. The given callee type overrides the type of the callee expression. """ return self.check_call(callee_type, e.args, e.arg_kinds, e, - e.arg_names, callable_node=e.callee)[0] + e.arg_names, callable_node=e.callee, + callable_name=callable_name)[0] def check_call(self, callee: Type, args: List[Expression], arg_kinds: List[int], context: Context, arg_names: List[str] = None, callable_node: Expression = None, - arg_messages: MessageBuilder = None) -> Tuple[Type, Type]: + arg_messages: MessageBuilder = None, + callable_name: Optional[str] = None) -> Tuple[Type, Type]: """Type check a call. Also infer type arguments if the callee is a generic function. @@ -406,6 +434,11 @@ def check_call(self, callee: Type, args: List[Expression], if callable_node: # Store the inferred callable type. self.chk.store_type(callable_node, callee) + if callable_name in self.function_plugins: + ret_type = self.apply_function_plugin( + arg_types, callee.ret_type, arg_kinds, formal_to_actual, + args, len(callee.arg_types), callable_name) + callee = callee.copy_modified(ret_type=ret_type) return callee.ret_type, callee elif isinstance(callee, Overloaded): # Type check arguments in empty context. They will be checked again diff --git a/mypy/funcplugins.py b/mypy/funcplugins.py new file mode 100644 index 000000000000..c27c064c9826 --- /dev/null +++ b/mypy/funcplugins.py @@ -0,0 +1,71 @@ +"""Plugins that implement special type checking rules for individual functions. + +The plugins infer better types for tricky functions such as "open". +""" + +from typing import Tuple, Dict, Callable, List + +from mypy.nodes import Expression, StrExpr +from mypy.types import Type, Instance, CallableType + + +PluginCallback = Callable[[List[Type], + Type, + List[Expression], + Callable[[str, List[Type]], Type]], + Type] + + +def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]: + """Return all available function plugins for a given Python version.""" + if python_version[0] == 3: + return { + 'builtins.open': open_callback, + 'contextlib.contextmanager': contextmanager_callback, + } + else: + return { + 'contextlib.contextmanager': contextmanager_callback, + } + + +def open_callback( + arg_types: List[Type], + inferred_return_type: Type, + args: List[Expression], + named_generic_type: Callable[[str, List[Type]], Type]) -> Type: + """Infer a better return type for 'open'. + + Infer IO[str] or IO[bytes] as the return value if the mode argument is not + given or is a literal. + """ + mode = None + if arg_types[1] is None: + mode = 'r' + elif isinstance(args[1], StrExpr): + mode = args[1].value + if mode is not None: + assert isinstance(inferred_return_type, Instance) + if 'b' in mode: + arg = named_generic_type('builtins.bytes', []) + else: + arg = named_generic_type('builtins.str', []) + return Instance(inferred_return_type.type, [arg]) + return inferred_return_type + + +def contextmanager_callback( + arg_types: List[Type], + inferred_return_type: Type, + args: List[Expression], + named_generic_type: Callable[[str, List[Type]], Type]) -> Type: + """Infer a better return type for 'contextlib.contextmanager'.""" + arg_type = arg_types[0] + if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType): + # The stub signature doesn't preserve information about arguments so + # add them back here. + return inferred_return_type.copy_modified( + arg_types=arg_type.arg_types, + arg_kinds=arg_type.arg_kinds, + arg_names=arg_type.arg_names) + return inferred_return_type diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index c7e24e05bfd3..170a8fc8f4a8 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -399,7 +399,20 @@ f.write('x') f.write(b'x') f.foobar() [out] -_program.py:4: error: IO[Any] has no attribute "foobar" +_program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str" +_program.py:4: error: IO[str] has no attribute "foobar" + +[case testOpenReturnTypeInference] +reveal_type(open('x')) +reveal_type(open('x', 'r')) +reveal_type(open('x', 'rb')) +mode = 'rb' +reveal_type(open('x', mode)) +[out] +_program.py:1: error: Revealed type is 'typing.IO[builtins.str]' +_program.py:2: error: Revealed type is 'typing.IO[builtins.str]' +_program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]' +_program.py:5: error: Revealed type is 'typing.IO[Any]' [case testGenericPatterns] from typing import Pattern @@ -1286,3 +1299,27 @@ a[1] = 2, 'y' a[:] = [('z', 3)] [out] _program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]") + +[case testContextManager] +import contextlib +from contextlib import contextmanager +from typing import Iterator + +@contextmanager +def f(x: int) -> Iterator[str]: + yield 'foo' + +@contextlib.contextmanager +def g(*x: str) -> Iterator[int]: + yield 1 + +reveal_type(f) +reveal_type(g) + +with f('') as s: + reveal_type(s) +[out] +_program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]' +_program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]' +_program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int" +_program.py:17: error: Revealed type is 'builtins.str*' From 9703800475a26abe02186b4b662c67f7a3e29f6c Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 16 May 2017 17:49:32 -0700 Subject: [PATCH 2/3] Add test case for additional special cases --- test-data/unit/pythoneval.test | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 170a8fc8f4a8..51ac27acecd0 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -414,6 +414,19 @@ _program.py:2: error: Revealed type is 'typing.IO[builtins.str]' _program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]' _program.py:5: error: Revealed type is 'typing.IO[Any]' +[case testOpenReturnTypeInferenceSpecialCases] +reveal_type(open()) +reveal_type(open(mode='rb', file='x')) +reveal_type(open(file='x', mode='rb')) +mode = 'rb' +reveal_type(open(mode=mode, file='r')) +[out] +_testOpenReturnTypeInferenceSpecialCases.py:1: error: Revealed type is 'typing.IO[builtins.str]' +_testOpenReturnTypeInferenceSpecialCases.py:1: error: Too few arguments for "open" +_testOpenReturnTypeInferenceSpecialCases.py:2: error: Revealed type is 'typing.IO[builtins.bytes]' +_testOpenReturnTypeInferenceSpecialCases.py:3: error: Revealed type is 'typing.IO[builtins.bytes]' +_testOpenReturnTypeInferenceSpecialCases.py:5: error: Revealed type is 'typing.IO[Any]' + [case testGenericPatterns] from typing import Pattern import re From 02ad7b6caa31026aa9d9a6ed475a1846f6485995 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 16 May 2017 17:56:10 -0700 Subject: [PATCH 3/3] Fix handling of arguments other than simple positional ones Also add comments and some defensive checks. --- mypy/checkexpr.py | 8 +++++--- mypy/funcplugins.py | 50 +++++++++++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ab727b5a9a87..016668b46105 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -349,12 +349,14 @@ def apply_function_plugin(self, Return the inferred return type. """ - formal_arg_types = [None] * num_formals # type: List[Optional[Type]] + formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] + formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]] for formal, actuals in enumerate(formal_to_actual): for actual in actuals: - formal_arg_types[formal] = arg_types[actual] + formal_arg_types[formal].append(arg_types[actual]) + formal_arg_exprs[formal].append(args[actual]) return self.function_plugins[fullname]( - formal_arg_types, inferred_ret_type, args, self.chk.named_generic_type) + formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type) def check_call_expr_with_callee_type(self, callee_type: Type, e: CallExpr, callable_name: Optional[str]) -> Type: diff --git a/mypy/funcplugins.py b/mypy/funcplugins.py index c27c064c9826..7d5c25248e0a 100644 --- a/mypy/funcplugins.py +++ b/mypy/funcplugins.py @@ -9,11 +9,19 @@ from mypy.types import Type, Instance, CallableType -PluginCallback = Callable[[List[Type], - Type, - List[Expression], - Callable[[str, List[Type]], Type]], - Type] +# A callback that infers the return type of a function with a special signature. +# +# A no-op callback would just return the inferred return type, but a useful callback +# at least sometimes can infer a more precise type. +PluginCallback = Callable[ + [ + List[List[Type]], # List of types caller provides for each formal argument + List[List[Expression]], # Actual argument expressions for each formal argument + Type, # Return type for call inferred using the regular signature + Callable[[str, List[Type]], Type] # Callable for constructing a named instance type + ], + Type # Return type inferred by the callback +] def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]: @@ -30,9 +38,9 @@ def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, def open_callback( - arg_types: List[Type], + arg_types: List[List[Type]], + args: List[List[Expression]], inferred_return_type: Type, - args: List[Expression], named_generic_type: Callable[[str, List[Type]], Type]) -> Type: """Infer a better return type for 'open'. @@ -40,10 +48,10 @@ def open_callback( given or is a literal. """ mode = None - if arg_types[1] is None: + if not arg_types or len(arg_types[1]) != 1: mode = 'r' - elif isinstance(args[1], StrExpr): - mode = args[1].value + elif isinstance(args[1][0], StrExpr): + mode = args[1][0].value if mode is not None: assert isinstance(inferred_return_type, Instance) if 'b' in mode: @@ -55,17 +63,19 @@ def open_callback( def contextmanager_callback( - arg_types: List[Type], + arg_types: List[List[Type]], + args: List[List[Expression]], inferred_return_type: Type, - args: List[Expression], named_generic_type: Callable[[str, List[Type]], Type]) -> Type: """Infer a better return type for 'contextlib.contextmanager'.""" - arg_type = arg_types[0] - if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType): - # The stub signature doesn't preserve information about arguments so - # add them back here. - return inferred_return_type.copy_modified( - arg_types=arg_type.arg_types, - arg_kinds=arg_type.arg_kinds, - arg_names=arg_type.arg_names) + # Be defensive, just in case. + if arg_types and len(arg_types[0]) == 1: + arg_type = arg_types[0][0] + if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType): + # The stub signature doesn't preserve information about arguments so + # add them back here. + return inferred_return_type.copy_modified( + arg_types=arg_type.arg_types, + arg_kinds=arg_type.arg_kinds, + arg_names=arg_type.arg_names) return inferred_return_type