diff --git a/mypy/build.py b/mypy/build.py index e850b4598c42..defe2fb38914 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -172,8 +172,9 @@ def build(sources: List[BuildSource], lib_path.insert(0, alt_lib_path) reports = Reports(data_dir, options.report_dirs) - source_set = BuildSourceSet(sources) + errors = Errors(options.show_error_context, options.show_column_numbers) + plugin = load_plugins(options, errors) # Construct a build manager object to hold state during the build. # @@ -184,9 +185,8 @@ def build(sources: List[BuildSource], reports=reports, options=options, version_id=__version__, - plugin=DefaultPlugin(options.python_version)) - - manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors) + plugin=plugin, + errors=errors) try: graph = dispatch(sources, manager) @@ -337,13 +337,14 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int: return toplevel_priority -def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin: - """Load custom plugins if any are configured. +def load_plugins(options: Options, errors: Errors) -> Plugin: + """Load all configured plugins. - Return a plugin that chains all custom plugins (if any) and falls - back to default_plugin. + Return a plugin that encapsulates all plugins chained together. Always + at least include the default plugin (it's last in the chain). """ + default_plugin = DefaultPlugin(options) # type: Plugin if not options.config_file: return default_plugin @@ -355,8 +356,8 @@ def plugin_error(message: str) -> None: errors.report(line, 0, message) errors.raise_error() + custom_plugins = [] # type: List[Plugin] errors.set_file(options.config_file, None) - custom_plugins = [] for plugin_path in options.plugins: # Plugin paths are relative to the config file location. plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path) @@ -395,15 +396,12 @@ def plugin_error(message: str) -> None: 'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" ' '(in {})'.format(plugin_path)) try: - custom_plugins.append(plugin_type(options.python_version)) + custom_plugins.append(plugin_type(options)) except Exception: print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__)) raise # Propagate to display traceback - if not custom_plugins: - return default_plugin - else: - # Custom plugins take precendence over built-in plugins. - return ChainedPlugin(options.python_version, custom_plugins + [default_plugin]) + # Custom plugins take precedence over the default plugin. + return ChainedPlugin(options, custom_plugins + [default_plugin]) def find_config_file_line_number(path: str, section: str, setting_name: str) -> int: @@ -447,12 +445,12 @@ class BuildManager: semantic_analyzer_pass3: Semantic analyzer, pass 3 all_types: Map {Expression: Type} collected from all modules - errors: Used for reporting all errors options: Build options missing_modules: Set of modules that could not be imported encountered so far stale_modules: Set of modules that needed to be rechecked version_id: The current mypy version (based on commit id when possible) plugin: Active mypy plugin(s) + errors: Used for reporting all errors """ def __init__(self, data_dir: str, @@ -462,10 +460,11 @@ def __init__(self, data_dir: str, reports: Reports, options: Options, version_id: str, - plugin: Plugin) -> None: + plugin: Plugin, + errors: Errors) -> None: self.start_time = time.time() self.data_dir = data_dir - self.errors = Errors(options.show_error_context, options.show_column_numbers) + self.errors = errors self.errors.set_ignore_prefix(ignore_prefix) self.lib_path = tuple(lib_path) self.source_set = source_set @@ -474,8 +473,9 @@ def __init__(self, data_dir: str, self.version_id = version_id self.modules = {} # type: Dict[str, MypyFile] self.missing_modules = set() # type: Set[str] + self.plugin = plugin self.semantic_analyzer = SemanticAnalyzer(self.modules, self.missing_modules, - lib_path, self.errors) + lib_path, self.errors, self.plugin) self.modules = self.semantic_analyzer.modules self.semantic_analyzer_pass3 = ThirdPass(self.modules, self.errors) self.all_types = {} # type: Dict[Expression, Type] diff --git a/mypy/checker.py b/mypy/checker.py index 98bb35474bf9..ed6d0dc0196d 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -57,7 +57,7 @@ from mypy.binder import ConditionalTypeBinder, get_declaration from mypy.meet import is_overlapping_types from mypy.options import Options -from mypy.plugin import Plugin +from mypy.plugin import Plugin, CheckerPluginInterface from mypy import experiments @@ -80,7 +80,7 @@ ]) -class TypeChecker(NodeVisitor[None]): +class TypeChecker(NodeVisitor[None], CheckerPluginInterface): """Mypy type checker. Type check mypy source files that have been semantically analyzed. diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c78919791ffa..3cdc3006d827 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -45,7 +45,7 @@ from mypy.util import split_module_names from mypy.typevars import fill_typevars from mypy.visitor import ExpressionVisitor -from mypy.plugin import Plugin, PluginContext, MethodSignatureHook +from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext from mypy.typeanal import make_optional_type from mypy import experiments @@ -380,6 +380,13 @@ def apply_function_plugin(self, context: Context) -> Type: """Use special case logic to infer the return type of a specific named function/method. + Caller must ensure that a plugin hook exists. There are two different cases: + + - If object_type is None, the caller must ensure that a function hook exists + for fullname. + - If object_type is not None, the caller must ensure that a method hook exists + for fullname. + Return the inferred return type. """ formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]] @@ -392,17 +399,21 @@ def apply_function_plugin(self, # Apply function plugin callback = self.plugin.get_function_hook(fullname) assert callback is not None # Assume that caller ensures this - return callback(formal_arg_types, formal_arg_exprs, inferred_ret_type, - self.chk.named_generic_type) + return callback( + FunctionContext(formal_arg_types, inferred_ret_type, formal_arg_exprs, + context, self.chk)) else: # Apply method plugin method_callback = self.plugin.get_method_hook(fullname) assert method_callback is not None # Assume that caller ensures this - return method_callback(object_type, formal_arg_types, formal_arg_exprs, - inferred_ret_type, self.create_plugin_context(context)) - - def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_type: Type, - signature_hook: MethodSignatureHook) -> FunctionLike: + return method_callback( + MethodContext(object_type, formal_arg_types, + inferred_ret_type, formal_arg_exprs, + context, self.chk)) + + def apply_method_signature_hook( + self, e: CallExpr, callee: FunctionLike, object_type: Type, + signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike: """Apply a plugin hook that may infer a more precise signature for a method.""" if isinstance(callee, CallableType): arg_kinds = e.arg_kinds @@ -417,8 +428,8 @@ def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_ for formal, actuals in enumerate(formal_to_actual): for actual in actuals: formal_arg_exprs[formal].append(args[actual]) - return signature_hook(object_type, formal_arg_exprs, callee, - self.chk.named_generic_type) + return signature_hook( + MethodSigContext(object_type, formal_arg_exprs, callee, e, self.chk)) else: assert isinstance(callee, Overloaded) items = [] @@ -428,9 +439,6 @@ def apply_method_signature_hook(self, e: CallExpr, callee: FunctionLike, object_ items.append(adjusted) return Overloaded(items) - def create_plugin_context(self, context: Context) -> PluginContext: - return PluginContext(self.chk.named_generic_type, self.msg, context) - def check_call_expr_with_callee_type(self, callee_type: Type, e: CallExpr, @@ -475,6 +483,8 @@ def check_call(self, callee: Type, args: List[Expression], """ arg_messages = arg_messages or self.msg if isinstance(callee, CallableType): + if callable_name is None and callee.name: + callable_name = callee.name if (isinstance(callable_node, RefExpr) and callable_node.fullname in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag')): diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 66949144eed8..d3f58a3a5917 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -17,6 +17,7 @@ from mypy.expandtype import expand_type_by_instance, expand_type, freshen_function_type_vars from mypy.infer import infer_type_arguments from mypy.typevars import fill_typevars +from mypy.plugin import Plugin, AttributeContext from mypy import messages from mypy import subtypes MYPY = False @@ -36,8 +37,8 @@ def analyze_member_access(name: str, not_ready_callback: Callable[[str, Context], None], msg: MessageBuilder, *, original_type: Type, - override_info: TypeInfo = None, - chk: 'mypy.checker.TypeChecker' = None) -> Type: + chk: 'mypy.checker.TypeChecker', + override_info: TypeInfo = None) -> Type: """Return the type of attribute `name` of typ. This is a general operation that supports various different variations: @@ -77,7 +78,7 @@ def analyze_member_access(name: str, assert isinstance(method, OverloadedFuncDef) first_item = cast(Decorator, method.items[0]) return analyze_var(name, first_item.var, typ, info, node, is_lvalue, msg, - original_type, not_ready_callback) + original_type, not_ready_callback, chk=chk) if is_lvalue: msg.cant_assign_to_method(node) signature = function_type(method, builtin_type('builtins.function')) @@ -102,7 +103,7 @@ def analyze_member_access(name: str, # The base object has dynamic type. return AnyType() elif isinstance(typ, NoneTyp): - if chk and chk.should_suppress_optional_error([typ]): + if chk.should_suppress_optional_error([typ]): return AnyType() # The only attribute NoneType has are those it inherits from object return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue, @@ -200,7 +201,7 @@ def analyze_member_access(name: str, is_operator, builtin_type, not_ready_callback, msg, original_type=original_type, chk=chk) - if chk and chk.should_suppress_optional_error([typ]): + if chk.should_suppress_optional_error([typ]): return AnyType() return msg.has_no_attr(original_type, typ, name, node) @@ -228,7 +229,7 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, if isinstance(v, Var): return analyze_var(name, v, itype, info, node, is_lvalue, msg, - original_type, not_ready_callback) + original_type, not_ready_callback, chk=chk) elif isinstance(v, FuncDef): assert False, "Did not expect a function" elif not v and name not in ['__getattr__', '__setattr__', '__getattribute__']: @@ -270,7 +271,8 @@ def analyze_member_var_access(name: str, itype: Instance, info: TypeInfo, def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Context, is_lvalue: bool, msg: MessageBuilder, original_type: Type, - not_ready_callback: Callable[[str, Context], None]) -> Type: + not_ready_callback: Callable[[str, Context], None], *, + chk: 'mypy.checker.TypeChecker') -> Type: """Analyze access to an attribute via a Var node. This is conceptually part of analyze_member_access and the arguments are similar. @@ -289,6 +291,7 @@ def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Cont msg.read_only_property(name, info, node) if is_lvalue and var.is_classvar: msg.cant_assign_to_classvar(name, node) + result = t if var.is_initialized_in_class and isinstance(t, FunctionLike) and not t.is_type_obj(): if is_lvalue: if var.is_property: @@ -308,15 +311,19 @@ def analyze_var(name: str, var: Var, itype: Instance, info: TypeInfo, node: Cont # A property cannot have an overloaded type => the cast # is fine. assert isinstance(signature, CallableType) - return signature.ret_type + result = signature.ret_type else: - return signature - return t + result = signature else: if not var.is_ready: not_ready_callback(var.name(), node) # Implicit 'Any' type. - return AnyType() + result = AnyType() + fullname = '{}.{}'.format(var.info.fullname(), name) + hook = chk.plugin.get_attribute_hook(fullname) + if hook: + result = hook(AttributeContext(original_type, result, node, chk)) + return result def freeze_type_vars(member_type: Type) -> None: diff --git a/mypy/plugin.py b/mypy/plugin.py index 7acd4d0b29a5..f94790a06e96 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -1,62 +1,101 @@ +"""Plugin system for extending mypy.""" + +from abc import abstractmethod from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context from mypy.types import ( Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType, - AnyType + AnyType, TypeList, UnboundType ) from mypy.messages import MessageBuilder +from mypy.options import Options + + +class AnalyzerPluginInterface: + """Interface for accessing semantic analyzer functionality in plugins.""" + + @abstractmethod + def fail(self, msg: str, ctx: Context) -> None: + raise NotImplementedError + + @abstractmethod + def named_type(self, name: str, args: List[Type]) -> Instance: + raise NotImplementedError + + @abstractmethod + def analyze_type(self, typ: Type) -> Type: + raise NotImplementedError + + @abstractmethod + def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], + List[int], + List[Optional[str]]]]: + raise NotImplementedError + + +# A context for a hook that semantically analyzes an unbound type. +AnalyzeTypeContext = NamedTuple( + 'AnalyzeTypeContext', [ + ('type', UnboundType), # Type to analyze + ('context', Context), + ('api', AnalyzerPluginInterface)]) -# Create an Instance given full name of class and type arguments. -NamedInstanceCallback = Callable[[str, List[Type]], Type] +class CheckerPluginInterface: + """Interface for accessing type checker functionality in plugins.""" -# Some objects and callbacks that plugins can use to get information from the -# type checker or to report errors. -PluginContext = NamedTuple('PluginContext', [('named_instance', NamedInstanceCallback), - ('msg', MessageBuilder), - ('context', Context)]) + msg = None # type: MessageBuilder + @abstractmethod + def named_generic_type(self, name: str, args: List[Type]) -> Instance: + raise NotImplementedError -# A callback that infers the return type of a function with a special signature. + +# A context for a function hook that infers the return type of a function with +# a special signature. # -# A no-op callback would just return the inferred return type, but a useful callback -# at least sometimes can infer a more precise type. -FunctionHook = Callable[ - [ - List[List[Type]], # List of types caller provides for each formal argument - List[List[Expression]], # Actual argument expressions for each formal argument - Type, # Return type for call inferred using the regular signature - NamedInstanceCallback # Callable for constructing a named instance type - ], - Type # Return type inferred by the callback -] - -# A callback that may infer a better signature for a method. Note that argument types aren't -# available yet. If you need them, you have to use a MethodHook instead. -MethodSignatureHook = Callable[ - [ - Type, # Base object type - List[List[Expression]], # Actual argument expressions for each formal argument - CallableType, # Original signature of the method - NamedInstanceCallback # Callable for constructing a named instance type - ], - CallableType # Potentially more precise signature inferred for the method -] - -# A callback that infers the return type of a method with a special signature. +# A no-op callback would just return the inferred return type, but a useful +# callback at least sometimes can infer a more precise type. +FunctionContext = NamedTuple( + 'FunctionContext', [ + ('arg_types', List[List[Type]]), # List of actual caller types for each formal argument + ('default_return_type', Type), # Return type inferred from signature + ('args', List[List[Expression]]), # Actual expressions for each formal argument + ('context', Context), + ('api', CheckerPluginInterface)]) + +# A context for a method signature hook that infers a better signature for a +# method. Note that argument types aren't available yet. If you need them, +# you have to use a method hook instead. +MethodSigContext = NamedTuple( + 'MethodSigContext', [ + ('type', Type), # Base object type for method call + ('args', List[List[Expression]]), # Actual expressions for each formal argument + ('default_signature', CallableType), # Original signature of the method + ('context', Context), + ('api', CheckerPluginInterface)]) + +# A context for a method hook that infers the return type of a method with a +# special signature. # -# This is pretty similar to FunctionHook. -MethodHook = Callable[ - [ - Type, # Base object type - List[List[Type]], # List of types caller provides for each formal argument - List[List[Expression]], # Actual argument expressions for each formal argument - Type, # Return type for call inferred using the regular signature - PluginContext # Access to type checking context - ], - Type # Return type inferred by the callback -] +# This is very similar to FunctionContext (only differences are documented). +MethodContext = NamedTuple( + 'MethodContext', [ + ('type', Type), # Base object type for method call + ('arg_types', List[List[Type]]), + ('default_return_type', Type), + ('args', List[List[Expression]]), + ('context', Context), + ('api', CheckerPluginInterface)]) + +# A context for an attribute type hook that infers the type of an attribute. +AttributeContext = NamedTuple( + 'AttributeContext', [ + ('type', Type), # Type of object with attribute + ('default_attr_type', Type), # Original attribute type + ('context', Context), + ('api', CheckerPluginInterface)]) class Plugin: @@ -67,18 +106,33 @@ class Plugin: All get_ methods are treated as pure functions (you should assume that results might be cached). + + Look at the comments of various *Context objects for descriptions of + various hooks. """ - def __init__(self, python_version: Tuple[int, int]) -> None: - self.python_version = python_version + def __init__(self, options: Options) -> None: + self.options = options + self.python_version = options.python_version + + def get_type_analyze_hook(self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: + return None + + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: + return None - def get_function_hook(self, fullname: str) -> Optional[FunctionHook]: + def get_method_signature_hook(self, fullname: str + ) -> Optional[Callable[[MethodSigContext], CallableType]]: return None - def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]: + def get_method_hook(self, fullname: str + ) -> Optional[Callable[[MethodContext], Type]]: return None - def get_method_hook(self, fullname: str) -> Optional[MethodHook]: + def get_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: return None # TODO: metaclass / class decorator hook @@ -99,23 +153,34 @@ class ChainedPlugin(Plugin): # TODO: Support caching of lookup results (through a LRU cache, for example). - def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None: + def __init__(self, options: Options, plugins: List[Plugin]) -> None: """Initialize chained plugin. Assume that the child plugins aren't mutated (results may be cached). """ - super().__init__(python_version) + super().__init__(options) self._plugins = plugins - def get_function_hook(self, fullname: str) -> Optional[FunctionHook]: + def get_type_analyze_hook(self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: + return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname)) + + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: return self._find_hook(lambda plugin: plugin.get_function_hook(fullname)) - def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]: + def get_method_signature_hook(self, fullname: str + ) -> Optional[Callable[[MethodSigContext], CallableType]]: return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname)) - def get_method_hook(self, fullname: str) -> Optional[MethodHook]: + def get_method_hook(self, fullname: str + ) -> Optional[Callable[[MethodContext], Type]]: return self._find_hook(lambda plugin: plugin.get_method_hook(fullname)) + def get_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname)) + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: for plugin in self._plugins: hook = lookup(plugin) @@ -127,19 +192,22 @@ def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" - def get_function_hook(self, fullname: str) -> Optional[FunctionHook]: + def get_function_hook(self, fullname: str + ) -> Optional[Callable[[FunctionContext], Type]]: if fullname == 'contextlib.contextmanager': return contextmanager_callback elif fullname == 'builtins.open' and self.python_version[0] == 3: return open_callback return None - def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]: + def get_method_signature_hook(self, fullname: str + ) -> Optional[Callable[[MethodSigContext], CallableType]]: if fullname == 'typing.Mapping.get': return typed_dict_get_signature_callback return None - def get_method_hook(self, fullname: str) -> Optional[MethodHook]: + def get_method_hook(self, fullname: str + ) -> Optional[Callable[[MethodContext], Type]]: if fullname == 'typing.Mapping.get': return typed_dict_get_callback elif fullname == 'builtins.int.__pow__': @@ -147,67 +215,57 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]: return None -def open_callback( - arg_types: List[List[Type]], - args: List[List[Expression]], - inferred_return_type: Type, - named_generic_type: Callable[[str, List[Type]], Type]) -> Type: +def open_callback(ctx: FunctionContext) -> Type: """Infer a better return type for 'open'. Infer TextIO or BinaryIO as the return value if the mode argument is not given or is a literal. """ mode = None - if not arg_types or len(arg_types[1]) != 1: + if not ctx.arg_types or len(ctx.arg_types[1]) != 1: mode = 'r' - elif isinstance(args[1][0], StrExpr): - mode = args[1][0].value + elif isinstance(ctx.args[1][0], StrExpr): + mode = ctx.args[1][0].value if mode is not None: - assert isinstance(inferred_return_type, Instance) + assert isinstance(ctx.default_return_type, Instance) if 'b' in mode: - return named_generic_type('typing.BinaryIO', []) + return ctx.api.named_generic_type('typing.BinaryIO', []) else: - return named_generic_type('typing.TextIO', []) - return inferred_return_type + return ctx.api.named_generic_type('typing.TextIO', []) + return ctx.default_return_type -def contextmanager_callback( - arg_types: List[List[Type]], - args: List[List[Expression]], - inferred_return_type: Type, - named_generic_type: Callable[[str, List[Type]], Type]) -> Type: +def contextmanager_callback(ctx: FunctionContext) -> Type: """Infer a better return type for 'contextlib.contextmanager'.""" # Be defensive, just in case. - if arg_types and len(arg_types[0]) == 1: - arg_type = arg_types[0][0] - if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType): + if ctx.arg_types and len(ctx.arg_types[0]) == 1: + arg_type = ctx.arg_types[0][0] + if (isinstance(arg_type, CallableType) + and isinstance(ctx.default_return_type, CallableType)): # The stub signature doesn't preserve information about arguments so # add them back here. - return inferred_return_type.copy_modified( + return ctx.default_return_type.copy_modified( arg_types=arg_type.arg_types, arg_kinds=arg_type.arg_kinds, arg_names=arg_type.arg_names) - return inferred_return_type + return ctx.default_return_type -def typed_dict_get_signature_callback( - object_type: Type, - args: List[List[Expression]], - signature: CallableType, - named_generic_type: Callable[[str, List[Type]], Type]) -> CallableType: +def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: """Try to infer a better signature type for TypedDict.get. This is used to get better type context for the second argument that depends on a TypedDict value type. """ - if (isinstance(object_type, TypedDictType) - and len(args) == 2 - and len(args[0]) == 1 - and isinstance(args[0][0], StrExpr) + signature = ctx.default_signature + if (isinstance(ctx.type, TypedDictType) + and len(ctx.args) == 2 + and len(ctx.args[0]) == 1 + and isinstance(ctx.args[0][0], StrExpr) and len(signature.arg_types) == 2 and len(signature.variables) == 1): - key = args[0][0].value - value_type = object_type.items.get(key) + key = ctx.args[0][0].value + value_type = ctx.type.items.get(key) if value_type: # Tweak the signature to include the value type as context. It's # only needed for type inference since there's a union with a type @@ -219,49 +277,39 @@ def typed_dict_get_signature_callback( return signature -def typed_dict_get_callback( - object_type: Type, - arg_types: List[List[Type]], - args: List[List[Expression]], - inferred_return_type: Type, - context: PluginContext) -> Type: +def typed_dict_get_callback(ctx: MethodContext) -> Type: """Infer a precise return type for TypedDict.get with literal first argument.""" - if (isinstance(object_type, TypedDictType) - and len(arg_types) >= 1 - and len(arg_types[0]) == 1): - if isinstance(args[0][0], StrExpr): - key = args[0][0].value - value_type = object_type.items.get(key) + if (isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) >= 1 + and len(ctx.arg_types[0]) == 1): + if isinstance(ctx.args[0][0], StrExpr): + key = ctx.args[0][0].value + value_type = ctx.type.items.get(key) if value_type: - if len(arg_types) == 1: + if len(ctx.arg_types) == 1: return UnionType.make_simplified_union([value_type, NoneTyp()]) - elif len(arg_types) == 2 and len(arg_types[1]) == 1: - return UnionType.make_simplified_union([value_type, arg_types[1][0]]) + elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1: + return UnionType.make_simplified_union([value_type, ctx.arg_types[1][0]]) else: - context.msg.typeddict_item_name_not_found(object_type, key, context.context) + ctx.api.msg.typeddict_item_name_not_found(ctx.type, key, ctx.context) return AnyType() - return inferred_return_type + return ctx.default_return_type -def int_pow_callback( - object_type: Type, - arg_types: List[List[Type]], - args: List[List[Expression]], - inferred_return_type: Type, - context: PluginContext) -> Type: +def int_pow_callback(ctx: MethodContext) -> Type: """Infer a more precise return type for int.__pow__.""" - if (len(arg_types) == 1 - and len(arg_types[0]) == 1): - arg = args[0][0] + if (len(ctx.arg_types) == 1 + and len(ctx.arg_types[0]) == 1): + arg = ctx.args[0][0] if isinstance(arg, IntExpr): exponent = arg.value elif isinstance(arg, UnaryExpr) and arg.op == '-' and isinstance(arg.expr, IntExpr): exponent = -arg.expr.value else: # Right operand not an int literal or a negated literal -- give up. - return inferred_return_type + return ctx.default_return_type if exponent >= 0: - return context.named_instance('builtins.int', []) + return ctx.api.named_generic_type('builtins.int', []) else: - return context.named_instance('builtins.float', []) - return inferred_return_type + return ctx.api.named_generic_type('builtins.float', []) + return ctx.default_return_type diff --git a/mypy/semanal.py b/mypy/semanal.py index 14630d964b29..1084e90ef55e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -90,6 +90,7 @@ from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.sametypes import is_same_type from mypy.options import Options +from mypy.plugin import Plugin from mypy import join @@ -230,11 +231,13 @@ class SemanticAnalyzer(NodeVisitor): is_stub_file = False # Are we analyzing a stub file? imports = None # type: Set[str] # Imported modules (during phase 2 analysis) errors = None # type: Errors # Keeps track of generated errors + plugin = None # type: Plugin # Mypy plugin for special casing of library features def __init__(self, modules: Dict[str, MypyFile], missing_modules: Set[str], - lib_path: List[str], errors: Errors) -> None: + lib_path: List[str], errors: Errors, + plugin: Plugin) -> None: """Construct semantic analyzer. Use lib_path to search for modules, and report analysis errors @@ -256,6 +259,7 @@ def __init__(self, self.postpone_nested_functions_stack = [FUNCTION_BOTH_PHASES] self.postponed_functions_stack = [] self.all_exports = set() # type: Set[str] + self.plugin = plugin def visit_file(self, file_node: MypyFile, fnam: str, options: Options) -> None: self.options = options @@ -1497,6 +1501,7 @@ def type_analyzer(self, *, self.lookup_fully_qualified, tvar_scope, self.fail, + self.plugin, aliasing=aliasing, allow_tuple_literal=allow_tuple_literal, allow_unnormalized=self.is_stub_file) @@ -1534,7 +1539,9 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.lookup_qualified, self.lookup_fully_qualified, self.tvar_scope, - self.fail, allow_unnormalized=True) + self.fail, + self.plugin, + allow_unnormalized=True) if res and (not isinstance(res, Instance) or res.args): # TODO: What if this gets reassigned? name = s.lvalues[0] @@ -3099,7 +3106,9 @@ def visit_index_expr(self, expr: IndexExpr) -> None: self.lookup_qualified, self.lookup_fully_qualified, self.tvar_scope, - self.fail, allow_unnormalized=self.is_stub_file) + self.fail, + self.plugin, + allow_unnormalized=self.is_stub_file) expr.analyzed = TypeAliasExpr(res, fallback=self.alias_fallback(res), in_runtime=True) elif refers_to_class_or_function(expr.base): diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index d168ad53e236..dbbe4872aa75 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -10,6 +10,7 @@ from mypy.report import Reports from mypy.plugin import Plugin from mypy import defaults +from mypy.errors import Errors class GraphSuite(Suite): @@ -36,15 +37,18 @@ def test_scc(self) -> None: frozenset({'D'})}) def _make_manager(self) -> BuildManager: + errors = Errors() + options = Options() manager = BuildManager( data_dir='', lib_path=[], ignore_prefix='', source_set=BuildSourceSet([]), reports=Reports('', {}), - options=Options(), + options=options, version_id=__version__, - plugin=Plugin(defaults.PYTHON3_VERSION), + plugin=Plugin(options), + errors=errors, ) return manager diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d36ecb98de6a..4dc6dc1bd02f 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -24,6 +24,7 @@ from mypy.sametypes import is_same_type from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.subtypes import is_subtype +from mypy.plugin import Plugin, AnalyzerPluginInterface, AnalyzeTypeContext from mypy import nodes from mypy import experiments @@ -54,6 +55,7 @@ def analyze_type_alias(node: Expression, lookup_fqn_func: Callable[[str], SymbolTableNode], tvar_scope: TypeVarScope, fail_func: Callable[[str, Context], None], + plugin: Plugin, allow_unnormalized: bool = False) -> Optional[Type]: """Return type if node is valid as a type alias rvalue. @@ -96,8 +98,8 @@ def analyze_type_alias(node: Expression, except TypeTranslationError: fail_func('Invalid type alias', node) return None - analyzer = TypeAnalyser(lookup_func, lookup_fqn_func, tvar_scope, fail_func, aliasing=True, - allow_unnormalized=allow_unnormalized) + analyzer = TypeAnalyser(lookup_func, lookup_fqn_func, tvar_scope, fail_func, plugin, + aliasing=True, allow_unnormalized=allow_unnormalized) return type.accept(analyzer) @@ -109,7 +111,7 @@ def no_subscript_builtin_alias(name: str, propose_alt: bool = True) -> str: return msg -class TypeAnalyser(SyntheticTypeVisitor[Type]): +class TypeAnalyser(SyntheticTypeVisitor[Type], AnalyzerPluginInterface): """Semantic analyzer for types (semantic analysis pass 2). Converts unbound types into bound types. @@ -119,19 +121,21 @@ def __init__(self, lookup_func: Callable[[str, Context], SymbolTableNode], lookup_fqn_func: Callable[[str], SymbolTableNode], tvar_scope: TypeVarScope, - fail_func: Callable[[str, Context], None], *, + fail_func: Callable[[str, Context], None], + plugin: Plugin, *, aliasing: bool = False, allow_tuple_literal: bool = False, allow_unnormalized: bool = False) -> None: self.lookup = lookup_func self.lookup_fqn_func = lookup_fqn_func - self.fail = fail_func + self.fail_func = fail_func self.tvar_scope = tvar_scope self.aliasing = aliasing self.allow_tuple_literal = allow_tuple_literal # Positive if we are analyzing arguments of another (outer) type self.nesting_level = 0 self.allow_unnormalized = allow_unnormalized + self.plugin = plugin def visit_unbound_type(self, t: UnboundType) -> Type: if t.optional: @@ -147,6 +151,9 @@ def visit_unbound_type(self, t: UnboundType) -> Type: self.fail('Internal error (node is None, kind={})'.format(sym.kind), t) return AnyType() fullname = sym.node.fullname() + hook = self.plugin.get_type_analyze_hook(fullname) + if hook: + return hook(AnalyzeTypeContext(t, t, self)) if (fullname in nongen_builtins and t.args and not sym.normalized and not self.allow_unnormalized): self.fail(no_subscript_builtin_alias(fullname), t) @@ -163,10 +170,10 @@ def visit_unbound_type(self, t: UnboundType) -> Type: elif fullname == 'typing.Tuple': if len(t.args) == 0 and not t.empty_tuple_index: # Bare 'Tuple' is same as 'tuple' - return self.builtin_type('builtins.tuple') + return self.named_type('builtins.tuple') if len(t.args) == 2 and isinstance(t.args[1], EllipsisType): # Tuple[T, ...] (uniform, variable-length tuple) - instance = self.builtin_type('builtins.tuple', [self.anal_type(t.args[0])]) + instance = self.named_type('builtins.tuple', [self.anal_type(t.args[0])]) instance.line = t.line return instance return self.tuple_type(self.anal_array(t.args)) @@ -341,7 +348,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: variables = self.bind_function_type_variables(t, t) ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested), ret_type=self.anal_type(t.ret_type, nested=nested), - fallback=t.fallback or self.builtin_type('builtins.function'), + fallback=t.fallback or self.named_type('builtins.function'), variables=self.anal_var_defs(variables)) return ret @@ -356,11 +363,11 @@ def visit_tuple_type(self, t: TupleType) -> Type: self.fail('At most one star type allowed in a tuple', t) if t.implicit: return TupleType([AnyType() for _ in t.items], - self.builtin_type('builtins.tuple'), + self.named_type('builtins.tuple'), t.line) else: return AnyType() - fallback = t.fallback if t.fallback else self.builtin_type('builtins.tuple', [AnyType()]) + fallback = t.fallback if t.fallback else self.named_type('builtins.tuple', [AnyType()]) return TupleType(self.anal_array(t.items), fallback, t.line) def visit_typeddict_type(self, t: TypedDictType) -> Type: @@ -387,7 +394,7 @@ def visit_type_type(self, t: TypeType) -> Type: return TypeType.make_normalized(self.anal_type(t.item), line=t.line) def analyze_callable_type(self, t: UnboundType) -> Type: - fallback = self.builtin_type('builtins.function') + fallback = self.named_type('builtins.function') if len(t.args) == 0: # Callable (bare). Treat as Callable[..., Any]. ret = CallableType([AnyType(), AnyType()], @@ -400,37 +407,10 @@ def analyze_callable_type(self, t: UnboundType) -> Type: ret_type = t.args[1] if isinstance(t.args[0], TypeList): # Callable[[ARG, ...], RET] (ordinary callable type) - args = [] # type: List[Type] - names = [] # type: List[str] - kinds = [] # type: List[int] - for arg in t.args[0].items: - if isinstance(arg, CallableArgument): - args.append(arg.typ) - names.append(arg.name) - if arg.constructor is None: - return AnyType() - found = self.lookup(arg.constructor, arg) - if found is None: - # Looking it up already put an error message in - return AnyType() - elif found.fullname not in ARG_KINDS_BY_CONSTRUCTOR: - self.fail('Invalid argument constructor "{}"'.format( - found.fullname), arg) - return AnyType() - else: - kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname] - kinds.append(kind) - if arg.name is not None and kind in {ARG_STAR, ARG_STAR2}: - self.fail("{} arguments should not have names".format( - arg.constructor), arg) - return AnyType() - else: - args.append(arg) - names.append(None) - kinds.append(ARG_POS) - - check_arg_names(names, [t] * len(args), self.fail, "Callable") - check_arg_kinds(kinds, [t] * len(args), self.fail) + analyzed_args = self.analyze_callable_args(t.args[0]) + if analyzed_args is None: + return AnyType() + args, kinds, names = analyzed_args ret = CallableType(args, kinds, names, @@ -453,6 +433,48 @@ def analyze_callable_type(self, t: UnboundType) -> Type: assert isinstance(ret, CallableType) return ret.accept(self) + def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], + List[int], + List[Optional[str]]]]: + args = [] # type: List[Type] + kinds = [] # type: List[int] + names = [] # type: List[str] + for arg in arglist.items: + if isinstance(arg, CallableArgument): + args.append(arg.typ) + names.append(arg.name) + if arg.constructor is None: + return None + found = self.lookup(arg.constructor, arg) + if found is None: + # Looking it up already put an error message in + return None + elif found.fullname not in ARG_KINDS_BY_CONSTRUCTOR: + self.fail('Invalid argument constructor "{}"'.format( + found.fullname), arg) + return None + else: + kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname] + kinds.append(kind) + if arg.name is not None and kind in {ARG_STAR, ARG_STAR2}: + self.fail("{} arguments should not have names".format( + arg.constructor), arg) + return None + else: + args.append(arg) + kinds.append(ARG_POS) + names.append(None) + # Note that arglist below is only used for error context. + check_arg_names(names, [arglist] * len(args), self.fail, "Callable") + check_arg_kinds(kinds, [arglist] * len(args), self.fail) + return args, kinds, names + + def analyze_type(self, t: Type) -> Type: + return t.accept(self) + + def fail(self, msg: str, ctx: Context) -> None: + self.fail_func(msg, ctx) + @contextmanager def tvar_scope_frame(self) -> Iterator[None]: old_scope = self.tvar_scope @@ -532,13 +554,13 @@ def anal_var_defs(self, var_defs: List[TypeVarDef]) -> List[TypeVarDef]: vd.line)) return a - def builtin_type(self, fully_qualified_name: str, args: List[Type] = None) -> Instance: + def named_type(self, fully_qualified_name: str, args: List[Type] = None) -> Instance: node = self.lookup_fqn_func(fully_qualified_name) assert isinstance(node.node, TypeInfo) return Instance(node.node, args or []) def tuple_type(self, items: List[Type]) -> TupleType: - return TupleType(items, fallback=self.builtin_type('builtins.tuple', [AnyType()])) + return TupleType(items, fallback=self.named_type('builtins.tuple', [AnyType()])) class TypeAnalyserPass3(TypeVisitor[None]): diff --git a/mypy/types.py b/mypy/types.py index d8598554aeb4..1f45da1237ac 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -549,7 +549,7 @@ class CallableType(FunctionLike): is_var_arg = False # Is it a varargs function? derived from arg_kinds is_kw_arg = False ret_type = None # type: Type # Return value type - name = '' # type: Optional[str] # Name (may be None; for error messages) + name = '' # type: Optional[str] # Name (may be None; for error messages and plugins) definition = None # type: Optional[SymbolNode] # For error messages. May be None. # Type variables for a generic function variables = None # type: List[TypeVarDef] diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 69d469a6f415..d1fdd9fbcfa9 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -83,3 +83,55 @@ f() plugins=/test-data/unit/plugins/badreturn2.py [out] tmp/mypy.ini:2: error: Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" (in /test-data/unit/plugins/badreturn2.py) + +[case testAttributeTypeHookPlugin] +# flags: --config-file tmp/mypy.ini +from typing import Callable +from m import Signal +s: Signal[Callable[[int], None]] = Signal() +s(1) +s('') # E: Argument 1 has incompatible type "str"; expected "int" +[file m.py] +from typing import TypeVar, Generic, Callable +T = TypeVar('T', bound=Callable[..., None]) +class Signal(Generic[T]): + __call__: Callable[..., None] # This type is replaced by the plugin +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/attrhook.py + +[case testTypeAnalyzeHookPlugin] +# flags: --config-file tmp/mypy.ini +from typing import Callable +from mypy_extensions import DefaultArg +from m import Signal +s: Signal[[int, DefaultArg(str, 'x')]] = Signal() +reveal_type(s) # E: Revealed type is 'm.Signal[def (builtins.int, x: builtins.str =)]' +s.x # E: Signal[Callable[[int, str], None]] has no attribute "x" +ss: Signal[int, str] # E: Invalid "Signal" type (expected "Signal[[t, ...]]") +[file m.py] +from typing import TypeVar, Generic, Callable +T = TypeVar('T', bound=Callable[..., None]) +class Signal(Generic[T]): + __call__: Callable[..., None] +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/type_anal_hook.py +[builtins fixtures/dict.pyi] + +[case testFunctionPluginHookForReturnedCallable] +# flags: --config-file tmp/mypy.ini +from m import decorator1, decorator2 +@decorator1() +def f() -> None: pass +@decorator2() +def g() -> None: pass +reveal_type(f) # E: Revealed type is 'def (*Any, **Any) -> builtins.str' +reveal_type(g) # E: Revealed type is 'def (*Any, **Any) -> builtins.int' +[file m.py] +from typing import Callable +def decorator1() -> Callable[..., Callable[..., int]]: pass +def decorator2() -> Callable[..., Callable[..., int]]: pass +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/named_callable.py diff --git a/test-data/unit/plugins/attrhook.py b/test-data/unit/plugins/attrhook.py new file mode 100644 index 000000000000..d94a5d6ccf79 --- /dev/null +++ b/test-data/unit/plugins/attrhook.py @@ -0,0 +1,21 @@ +from typing import Optional, Callable + +from mypy.plugin import Plugin, AttributeContext +from mypy.types import Type, Instance + + +class AttrPlugin(Plugin): + def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], Type]]: + if fullname == 'm.Signal.__call__': + return signal_call_callback + return None + + +def signal_call_callback(ctx: AttributeContext) -> Type: + if isinstance(ctx.type, Instance) and ctx.type.type.fullname() == 'm.Signal': + return ctx.type.args[0] + return ctx.inferred_attr_type + + +def plugin(version): + return AttrPlugin diff --git a/test-data/unit/plugins/fnplugin.py b/test-data/unit/plugins/fnplugin.py index d5027219a09f..513279213b7d 100644 --- a/test-data/unit/plugins/fnplugin.py +++ b/test-data/unit/plugins/fnplugin.py @@ -6,8 +6,8 @@ def get_function_hook(self, fullname): return my_hook return None -def my_hook(arg_types, args, inferred_return_type, named_generic_type): - return named_generic_type('builtins.int', []) +def my_hook(ctx): + return ctx.api.named_generic_type('builtins.int', []) def plugin(version): return MyPlugin diff --git a/test-data/unit/plugins/named_callable.py b/test-data/unit/plugins/named_callable.py new file mode 100644 index 000000000000..e40d181d2bad --- /dev/null +++ b/test-data/unit/plugins/named_callable.py @@ -0,0 +1,28 @@ +from mypy.plugin import Plugin +from mypy.types import CallableType + + +class MyPlugin(Plugin): + def get_function_hook(self, fullname): + if fullname == 'm.decorator1': + return decorator_call_hook + if fullname == 'm._decorated': # This is a dummy name generated by the plugin + return decorate_hook + return None + + +def decorator_call_hook(ctx): + if isinstance(ctx.default_return_type, CallableType): + return ctx.default_return_type.copy_modified(name='m._decorated') + return ctx.default_return_type + + +def decorate_hook(ctx): + if isinstance(ctx.default_return_type, CallableType): + return ctx.default_return_type.copy_modified( + ret_type=ctx.api.named_generic_type('builtins.str', [])) + return ctx.default_return_type + + +def plugin(version): + return MyPlugin diff --git a/test-data/unit/plugins/plugin2.py b/test-data/unit/plugins/plugin2.py index 1584871fae1d..b530a62d23aa 100644 --- a/test-data/unit/plugins/plugin2.py +++ b/test-data/unit/plugins/plugin2.py @@ -6,8 +6,8 @@ def get_function_hook(self, fullname): return str_hook return None -def str_hook(arg_types, args, inferred_return_type, named_generic_type): - return named_generic_type('builtins.str', []) +def str_hook(ctx): + return ctx.api.named_generic_type('builtins.str', []) def plugin(version): return Plugin2 diff --git a/test-data/unit/plugins/type_anal_hook.py b/test-data/unit/plugins/type_anal_hook.py new file mode 100644 index 000000000000..0e7a0ee9409c --- /dev/null +++ b/test-data/unit/plugins/type_anal_hook.py @@ -0,0 +1,37 @@ +from typing import Optional, Callable + +from mypy.plugin import Plugin, AnalyzeTypeContext +from mypy.types import Type, UnboundType, TypeList, AnyType, NoneTyp, CallableType + + +class TypeAnalyzePlugin(Plugin): + def get_type_analyze_hook(self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: + if fullname == 'm.Signal': + return signal_type_analyze_callback + return None + + +def signal_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type: + if (len(ctx.type.args) != 1 + or not isinstance(ctx.type.args[0], TypeList)): + ctx.api.fail('Invalid "Signal" type (expected "Signal[[t, ...]]")', ctx.context) + return AnyType() + + args = ctx.type.args[0] + assert isinstance(args, TypeList) + analyzed = ctx.api.analyze_callable_args(args) + if analyzed is None: + return AnyType() # Error generated elsewhere + arg_types, arg_kinds, arg_names = analyzed + arg_types = [ctx.api.analyze_type(arg) for arg in arg_types] + type_arg = CallableType(arg_types, + arg_kinds, + arg_names, + NoneTyp(), + ctx.api.named_type('builtins.function', [])) + return ctx.api.named_type('m.Signal', [type_arg]) + + +def plugin(version): + return TypeAnalyzePlugin