diff --git a/mypy/build.py b/mypy/build.py index 41434adf6c79..58089af97979 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -42,7 +42,7 @@ from mypy.stats import dump_type_stats from mypy.types import Type from mypy.version import __version__ -from mypy.plugin import DefaultPlugin +from mypy.plugin import Plugin, DefaultPlugin, ChainedPlugin # We need to know the location of this file to load data, but @@ -183,7 +183,9 @@ def build(sources: List[BuildSource], reports=reports, options=options, version_id=__version__, - ) + plugin=DefaultPlugin(options.python_version)) + + manager.plugin = load_custom_plugins(manager.plugin, options, manager.errors) try: graph = dispatch(sources, manager) @@ -333,6 +335,67 @@ def import_priority(imp: ImportBase, toplevel_priority: int) -> int: return toplevel_priority +def load_custom_plugins(default_plugin: Plugin, options: Options, errors: Errors) -> Plugin: + """Load custom plugins if any are configured. + + Return a plugin that chains all custom plugins (if any) and falls + back to default_plugin. + """ + + def plugin_error(message: str) -> None: + errors.report(0, 0, message) + errors.raise_error() + + custom_plugins = [] + for plugin_path in options.plugins: + if options.config_file: + # Plugin paths are relative to the config file location. + plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path) + errors.set_file(plugin_path, None) + + if not os.path.isfile(plugin_path): + plugin_error("Can't find plugin") + plugin_dir = os.path.dirname(plugin_path) + fnam = os.path.basename(plugin_path) + if not fnam.endswith('.py'): + plugin_error("Plugin must have .py extension") + module_name = fnam[:-3] + import importlib + sys.path.insert(0, plugin_dir) + try: + m = importlib.import_module(module_name) + except Exception: + print('Error importing plugin {}\n'.format(plugin_path)) + raise # Propagate to display traceback + finally: + assert sys.path[0] == plugin_dir + del sys.path[0] + if not hasattr(m, 'plugin'): + plugin_error('Plugin does not define entry point function "plugin"') + try: + plugin_type = getattr(m, 'plugin')(__version__) + except Exception: + print('Error calling the plugin(version) entry point of {}\n'.format(plugin_path)) + raise # Propagate to display traceback + if not isinstance(plugin_type, type): + plugin_error( + 'Type object expected as the return value of "plugin" (got {!r})'.format( + plugin_type)) + if not issubclass(plugin_type, Plugin): + plugin_error( + 'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin"') + try: + custom_plugins.append(plugin_type(options.python_version)) + except Exception: + print('Error constructing plugin instance of {}\n'.format(plugin_type.__name__)) + raise # Propagate to display traceback + if not custom_plugins: + return default_plugin + else: + # Custom plugins take precendence over built-in plugins. + return ChainedPlugin(options.python_version, custom_plugins + [default_plugin]) + + # TODO: Get rid of all_types. It's not used except for one log message. # Maybe we could instead publish a map from module ID to its type_map. class BuildManager: @@ -356,6 +419,7 @@ class BuildManager: missing_modules: Set of modules that could not be imported encountered so far stale_modules: Set of modules that needed to be rechecked version_id: The current mypy version (based on commit id when possible) + plugin: Active mypy plugin(s) """ def __init__(self, data_dir: str, @@ -364,7 +428,8 @@ def __init__(self, data_dir: str, source_set: BuildSourceSet, reports: Reports, options: Options, - version_id: str) -> None: + version_id: str, + plugin: Plugin) -> None: self.start_time = time.time() self.data_dir = data_dir self.errors = Errors(options.show_error_context, options.show_column_numbers) @@ -384,6 +449,7 @@ def __init__(self, data_dir: str, self.indirection_detector = TypeIndirectionVisitor() self.stale_modules = set() # type: Set[str] self.rechecked_modules = set() # type: Set[str] + self.plugin = plugin def maybe_swap_for_shadow_path(self, path: str) -> str: if (self.options.shadow_file and @@ -1506,9 +1572,8 @@ def type_check_first_pass(self) -> None: if self.options.semantic_analysis_only: return with self.wrap_context(): - plugin = DefaultPlugin(self.options.python_version) self.type_checker = TypeChecker(manager.errors, manager.modules, self.options, - self.tree, self.xpath, plugin) + self.tree, self.xpath, manager.plugin) self.type_checker.check_first_pass() def type_check_second_pass(self) -> bool: diff --git a/mypy/main.py b/mypy/main.py index 2d0fe9ec7607..7d344530461c 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -372,7 +372,7 @@ def disallow_any_argument_type(raw_options: str) -> List[str]: parser.parse_args(args, dummy) config_file = dummy.config_file if config_file is not None and not os.path.exists(config_file): - parser.error("Cannot file config file '%s'" % config_file) + parser.error("Cannot find config file '%s'" % config_file) # Parse config file first, so command line can override. options = Options() @@ -605,6 +605,7 @@ def get_init_file(dir: str) -> Optional[str]: # These two are for backwards compatibility 'silent_imports': bool, 'almost_silent': bool, + 'plugins': lambda s: [p.strip() for p in s.split(',')], } SHARED_CONFIG_FILES = ('setup.cfg',) diff --git a/mypy/options.py b/mypy/options.py index 69f99cce9501..fac8fe6d4459 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -113,6 +113,9 @@ def __init__(self) -> None: self.debug_cache = False self.quick_and_dirty = False + # Paths of user plugins + self.plugins = [] # type: List[str] + # Per-module options (raw) self.per_module_options = {} # type: Dict[Pattern[str], Dict[str, object]] diff --git a/mypy/plugin.py b/mypy/plugin.py index 5015f7b4c940..7acd4d0b29a5 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Optional, NamedTuple +from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context from mypy.types import ( @@ -60,7 +60,7 @@ class Plugin: - """Base class of type checker plugins. + """Base class of all type checker plugins. This defines a no-op plugin. Subclasses can override some methods to provide some actual functionality. @@ -69,8 +69,6 @@ class Plugin: results might be cached). """ - # TODO: Way of chaining multiple plugins - def __init__(self, python_version: Tuple[int, int]) -> None: self.python_version = python_version @@ -86,6 +84,46 @@ def get_method_hook(self, fullname: str) -> Optional[MethodHook]: # TODO: metaclass / class decorator hook +T = TypeVar('T') + + +class ChainedPlugin(Plugin): + """A plugin that represents a sequence of chained plugins. + + Each lookup method returns the hook for the first plugin that + reports a match. + + This class should not be subclassed -- use Plugin as the base class + for all plugins. + """ + + # TODO: Support caching of lookup results (through a LRU cache, for example). + + def __init__(self, python_version: Tuple[int, int], plugins: List[Plugin]) -> None: + """Initialize chained plugin. + + Assume that the child plugins aren't mutated (results may be cached). + """ + super().__init__(python_version) + self._plugins = plugins + + def get_function_hook(self, fullname: str) -> Optional[FunctionHook]: + return self._find_hook(lambda plugin: plugin.get_function_hook(fullname)) + + def get_method_signature_hook(self, fullname: str) -> Optional[MethodSignatureHook]: + return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname)) + + def get_method_hook(self, fullname: str) -> Optional[MethodHook]: + return self._find_hook(lambda plugin: plugin.get_method_hook(fullname)) + + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: + for plugin in self._plugins: + hook = lookup(plugin) + if hook: + return hook + return None + + class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" diff --git a/mypy/test/data.py b/mypy/test/data.py index ccee92eac276..09fe931d0c62 100644 --- a/mypy/test/data.py +++ b/mypy/test/data.py @@ -13,6 +13,9 @@ from mypy.myunit import TestCase, SkipTestCaseException +root_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), '..', '..')) + + def parse_test_cases( path: str, perform: Optional[Callable[['DataDrivenTestCase'], None]], @@ -62,7 +65,9 @@ def parse_test_cases( # Record an extra file needed for the test case. arg = p[i].arg assert arg is not None - file_entry = (join(base_path, arg), '\n'.join(p[i].data)) + contents = '\n'.join(p[i].data) + contents = expand_variables(contents) + file_entry = (join(base_path, arg), contents) if p[i].id == 'file': files.append(file_entry) elif p[i].id == 'outfile': @@ -119,13 +124,15 @@ def parse_test_cases( deleted_paths.setdefault(num, set()).add(full) elif p[i].id == 'out' or p[i].id == 'out1': tcout = p[i].data - if native_sep and os.path.sep == '\\': + tcout = [expand_variables(line) for line in tcout] + if os.path.sep == '\\': tcout = [fix_win_path(line) for line in tcout] ok = True elif re.match(r'out[0-9]*$', p[i].id): passnum = int(p[i].id[3:]) assert passnum > 1 output = p[i].data + output = [expand_variables(line) for line in output] if native_sep and os.path.sep == '\\': output = [fix_win_path(line) for line in output] tcout2[passnum] = output @@ -415,6 +422,10 @@ def expand_includes(a: List[str], base_path: str) -> List[str]: return res +def expand_variables(s: str) -> str: + return s.replace('', root_dir) + + def expand_errors(input: List[str], output: List[str], fnam: str) -> None: """Transform comments such as '# E: message' or '# E:3: message' in input. @@ -445,16 +456,17 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None: def fix_win_path(line: str) -> str: - r"""Changes paths to Windows paths in error messages. + r"""Changes Windows paths to Linux paths in error messages. - E.g. foo/bar.py -> foo\bar.py. + E.g. foo\bar.py -> foo/bar.py. """ + line = line.replace(root_dir, root_dir.replace('\\', '/')) m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line) if not m: return line else: filename, lineno, message = m.groups() - return '{}:{}{}'.format(filename.replace('/', '\\'), + return '{}:{}{}'.format(filename.replace('\\', '/'), lineno or '', message) diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 91a818ac0f01..b01424b8a053 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -76,6 +76,7 @@ 'check-classvar.test', 'check-enum.test', 'check-incomplete-fixture.test', + 'check-custom-plugin.test', ] @@ -261,7 +262,8 @@ def find_error_paths(self, a: List[str]) -> Set[str]: for line in a: m = re.match(r'([^\s:]+):\d+: error:', line) if m: - p = m.group(1).replace('/', os.path.sep) + # Normalize to Linux paths. + p = m.group(1).replace(os.path.sep, '/') hits.add(p) return hits diff --git a/mypy/test/testcmdline.py b/mypy/test/testcmdline.py index 8e56b42bc766..06009f107b14 100644 --- a/mypy/test/testcmdline.py +++ b/mypy/test/testcmdline.py @@ -15,7 +15,7 @@ from mypy.test.config import test_data_prefix, test_temp_dir from mypy.test.data import fix_cobertura_filename from mypy.test.data import parse_test_cases, DataDrivenTestCase -from mypy.test.helpers import assert_string_arrays_equal +from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages from mypy.version import __version__, base_version # Path to Python 3 interpreter @@ -71,10 +71,12 @@ def test_python_evaluation(testcase: DataDrivenTestCase) -> None: os.path.abspath(test_temp_dir)) if testcase.native_sep and os.path.sep == '\\': normalized_output = [fix_cobertura_filename(line) for line in normalized_output] + normalized_output = normalize_error_messages(normalized_output) assert_string_arrays_equal(expected_content.splitlines(), normalized_output, 'Output file {} did not match its expected output'.format( path)) else: + out = normalize_error_messages(out) assert_string_arrays_equal(testcase.output, out, 'Invalid output ({}, line {})'.format( testcase.file, testcase.line)) diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index 7a9062914f89..d168ad53e236 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -8,6 +8,8 @@ from mypy.version import __version__ from mypy.options import Options from mypy.report import Reports +from mypy.plugin import Plugin +from mypy import defaults class GraphSuite(Suite): @@ -42,6 +44,7 @@ def _make_manager(self) -> BuildManager: reports=Reports('', {}), options=Options(), version_id=__version__, + plugin=Plugin(defaults.PYTHON3_VERSION), ) return manager diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index 99c0078e9196..6d7f2ddb24bb 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -89,6 +89,7 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: a += str(f).split('\n') except CompileError as e: a = e.messages + a = normalize_error_messages(a) assert_string_arrays_equal( testcase.output, a, 'Invalid semantic analyzer output ({}, line {})'.format(testcase.file, diff --git a/mypy/test/testtransform.py b/mypy/test/testtransform.py index 1dac3081efbd..0dcdd1d0c649 100644 --- a/mypy/test/testtransform.py +++ b/mypy/test/testtransform.py @@ -7,7 +7,9 @@ from mypy import build from mypy.build import BuildSource from mypy.myunit import Suite -from mypy.test.helpers import assert_string_arrays_equal, testfile_pyversion +from mypy.test.helpers import ( + assert_string_arrays_equal, testfile_pyversion, normalize_error_messages +) from mypy.test.data import parse_test_cases, DataDrivenTestCase from mypy.test.config import test_data_prefix, test_temp_dir from mypy.errors import CompileError @@ -73,6 +75,7 @@ def test_transform(testcase: DataDrivenTestCase) -> None: a += str(f).split('\n') except CompileError as e: a = e.messages + a = normalize_error_messages(a) assert_string_arrays_equal( testcase.output, a, 'Invalid semantic analyzer output ({}, line {})'.format(testcase.file, diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test new file mode 100644 index 000000000000..30b00a4b3a62 --- /dev/null +++ b/test-data/unit/check-custom-plugin.test @@ -0,0 +1,71 @@ +-- Test cases for user-defined plugins +-- +-- Note: Plugins used by tests live under test-data/unit/plugins. Defining +-- plugin files in test cases does not work reliably. + +[case testFunctionPlugin] +# flags: --config-file tmp/mypy.ini +def f() -> str: ... +reveal_type(f()) # E: Revealed type is 'builtins.int' +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/fnplugin.py + +[case testTwoPlugins] +# flags: --config-file tmp/mypy.ini +def f(): ... +def g(): ... +def h(): ... +reveal_type(f()) # E: Revealed type is 'builtins.int' +reveal_type(g()) # E: Revealed type is 'builtins.str' +reveal_type(h()) # E: Revealed type is 'Any' +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/fnplugin.py, + /test-data/unit/plugins/plugin2.py + +[case testMissingPlugin] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=missing.py +[out] +tmp/missing.py:0: error: Can't find plugin +--' (work around syntax highlighting) + +[case testInvalidPluginExtension] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=badext.pyi +[file badext.pyi] +[out] +tmp/badext.pyi:0: error: Plugin must have .py extension + +[case testMissingPluginEntryPoint] +# flags: --config-file tmp/mypy.ini +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/noentry.py +[out] +/test-data/unit/plugins/noentry.py:0: error: Plugin does not define entry point function "plugin" + +[case testInvalidPluginEntryPointReturnValue] +# flags: --config-file tmp/mypy.ini +def f(): pass +f() +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/badreturn.py +[out] +/test-data/unit/plugins/badreturn.py:0: error: Type object expected as the return value of "plugin" (got None) + +[case testInvalidPluginEntryPointReturnValue2] +# flags: --config-file tmp/mypy.ini +def f(): pass +f() +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/badreturn2.py +[out] +/test-data/unit/plugins/badreturn2.py:0: error: Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" diff --git a/test-data/unit/plugins/badreturn.py b/test-data/unit/plugins/badreturn.py new file mode 100644 index 000000000000..fd7430606dd6 --- /dev/null +++ b/test-data/unit/plugins/badreturn.py @@ -0,0 +1,2 @@ +def plugin(version): + pass diff --git a/test-data/unit/plugins/badreturn2.py b/test-data/unit/plugins/badreturn2.py new file mode 100644 index 000000000000..c7e0447841c1 --- /dev/null +++ b/test-data/unit/plugins/badreturn2.py @@ -0,0 +1,5 @@ +class MyPlugin: + pass + +def plugin(version): + return MyPlugin diff --git a/test-data/unit/plugins/fnplugin.py b/test-data/unit/plugins/fnplugin.py new file mode 100644 index 000000000000..d5027219a09f --- /dev/null +++ b/test-data/unit/plugins/fnplugin.py @@ -0,0 +1,13 @@ +from mypy.plugin import Plugin + +class MyPlugin(Plugin): + def get_function_hook(self, fullname): + if fullname == '__main__.f': + return my_hook + return None + +def my_hook(arg_types, args, inferred_return_type, named_generic_type): + return named_generic_type('builtins.int', []) + +def plugin(version): + return MyPlugin diff --git a/test-data/unit/plugins/noentry.py b/test-data/unit/plugins/noentry.py new file mode 100644 index 000000000000..c591ad11fd64 --- /dev/null +++ b/test-data/unit/plugins/noentry.py @@ -0,0 +1 @@ +# empty plugin diff --git a/test-data/unit/plugins/plugin2.py b/test-data/unit/plugins/plugin2.py new file mode 100644 index 000000000000..1584871fae1d --- /dev/null +++ b/test-data/unit/plugins/plugin2.py @@ -0,0 +1,13 @@ +from mypy.plugin import Plugin + +class Plugin2(Plugin): + def get_function_hook(self, fullname): + if fullname in ('__main__.f', '__main__.g'): + return str_hook + return None + +def str_hook(arg_types, args, inferred_return_type, named_generic_type): + return named_generic_type('builtins.str', []) + +def plugin(version): + return Plugin2