diff --git a/mypy/nodes.py b/mypy/nodes.py index d10bfafa484c..4d725e006b45 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2285,6 +2285,7 @@ class is generic then it will be a type constructor of higher kind. # Used to stash the names of the mro classes temporarily between # deserialization and fixup. See deserialize() for why. _mro_refs = None # type: Optional[List[str]] + bad_mro = False # Could not construct full MRO declared_metaclass = None # type: Optional[mypy.types.Instance] metaclass_type = None # type: Optional[mypy.types.Instance] diff --git a/mypy/options.py b/mypy/options.py index f120dba8715b..beaedbc4277d 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -3,8 +3,8 @@ import pprint import sys -from typing import Dict, List, Mapping, Optional, Pattern, Set, Tuple from typing_extensions import Final +from typing import Dict, List, Mapping, Optional, Pattern, Set, Tuple, Callable, Any from mypy import defaults from mypy.util import get_class_descriptors, replace_object_state @@ -262,6 +262,9 @@ def __init__(self) -> None: self.cache_map = {} # type: Dict[str, Tuple[str, str]] # Don't properly free objects on exit, just kill the current process. self.fast_exit = False + # Used to transform source code before parsing if not None + # TODO: Make the type precise (AnyStr -> AnyStr) + self.transform_source = None # type: Optional[Callable[[Any], Any]] # Print full path to each file in the report. self.show_absolute_path = False # type: bool diff --git a/mypy/parse.py b/mypy/parse.py index 149a0bbb6196..c39a2388028a 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -18,6 +18,8 @@ def parse(source: Union[str, bytes], The python_version (major, minor) option determines the Python syntax variant. """ is_stub_file = fnam.endswith('.pyi') + if options.transform_source is not None: + source = options.transform_source(source) if options.python_version[0] >= 3 or is_stub_file: import mypy.fastparse return mypy.fastparse.parse(source, diff --git a/mypy/semanal.py b/mypy/semanal.py index e864367ee14e..31ad131335a0 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1490,8 +1490,7 @@ def configure_base_classes(self, # Calculate the MRO. if not self.verify_base_classes(defn): - # Give it an MRO consisting of just the class itself and object. - defn.info.mro = [defn.info, self.object_type().type] + self.set_dummy_mro(defn.info) return self.calculate_class_mro(defn, self.object_type) @@ -1519,6 +1518,11 @@ def configure_tuple_base_class(self, return base.partial_fallback + def set_dummy_mro(self, info: TypeInfo) -> None: + # Give it an MRO consisting of just the class itself and object. + info.mro = [info, self.object_type().type] + info.bad_mro = True + def calculate_class_mro(self, defn: ClassDef, obj_type: Optional[Callable[[], Instance]] = None) -> None: """Calculate method resolution order for a class. @@ -1530,9 +1534,9 @@ def calculate_class_mro(self, defn: ClassDef, try: calculate_mro(defn.info, obj_type) except MroError: - self.fail_blocker('Cannot determine consistent method resolution ' - 'order (MRO) for "%s"' % defn.name, defn) - defn.info.mro = [] + self.fail('Cannot determine consistent method resolution ' + 'order (MRO) for "%s"' % defn.name, defn) + self.set_dummy_mro(defn.info) # Allow plugins to alter the MRO to handle the fact that `def mro()` # on metaclasses permits MRO rewriting. if defn.fullname: @@ -1597,12 +1601,12 @@ def update_metaclass(self, defn: ClassDef) -> None: def verify_base_classes(self, defn: ClassDef) -> bool: info = defn.info + cycle = False for base in info.bases: baseinfo = base.type if self.is_base_class(info, baseinfo): - self.fail('Cycle in inheritance hierarchy', defn, blocker=True) - # Clear bases to forcefully get rid of the cycle. - info.bases = [] + self.fail('Cycle in inheritance hierarchy', defn) + cycle = True if baseinfo.fullname == 'builtins.bool': self.fail("'%s' is not a valid base class" % baseinfo.name, defn, blocker=True) @@ -1611,7 +1615,7 @@ def verify_base_classes(self, defn: ClassDef) -> bool: if dup: self.fail('Duplicate base class "%s"' % dup.name, defn, blocker=True) return False - return True + return not cycle def is_base_class(self, t: TypeInfo, s: TypeInfo) -> bool: """Determine if t is a base class of s (but do not use mro).""" diff --git a/mypy/stubdoc.py b/mypy/stubdoc.py index bc10fe171ca5..48c678e6af31 100644 --- a/mypy/stubdoc.py +++ b/mypy/stubdoc.py @@ -199,8 +199,12 @@ def infer_sig_from_docstring(docstr: str, name: str) -> Optional[List[FunctionSi state = DocStringParser(name) # Return all found signatures, even if there is a parse error after some are found. with contextlib.suppress(tokenize.TokenError): - for token in tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline): - state.add_token(token) + try: + tokens = tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline) + for token in tokens: + state.add_token(token) + except IndentationError: + return None sigs = state.get_signatures() def is_unique_args(sig: FunctionSig) -> bool: diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 5fb4807a8013..f7edbdf2d4c4 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -62,6 +62,7 @@ import mypy.parse import mypy.errors import mypy.traverser +import mypy.mixedtraverser import mypy.util from mypy import defaults from mypy.modulefinder import FindModuleCache, SearchPaths, BuildSource, default_lib_path @@ -69,20 +70,20 @@ Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo, - IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, + IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, Block, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT ) from mypy.stubgenc import generate_stub_for_c_module from mypy.stubutil import ( - write_header, default_py2_interpreter, CantImport, generate_guarded, + default_py2_interpreter, CantImport, generate_guarded, walk_packages, find_module_path_and_all_py2, find_module_path_and_all_py3, - report_missing, fail_missing + report_missing, fail_missing, remove_misplaced_type_comments, common_dir_prefix ) from mypy.stubdoc import parse_all_signatures, find_unique_signatures, Sig from mypy.options import Options as MypyOptions from mypy.types import ( - Type, TypeStrVisitor, CallableType, AnyType, - UnboundType, NoneType, TupleType, TypeList, + Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, + AnyType ) from mypy.visitor import NodeVisitor from mypy.find_sources import create_source_list, InvalidSourceList @@ -91,15 +92,49 @@ from mypy.traverser import has_return_statement +# Common ways of naming package containing vendored modules. +VENDOR_PACKAGES = [ + 'packages', + 'vendor', + 'vendored', +] + +# Avoid some file names that are unnecessary or likely to cause trouble (\n for end of path). +BLACKLIST = [ + '/six.py\n', # Likely vendored six; too dynamic for us to handle + '/vendored/', # Vendored packages + '/vendor/', # Vendored packages +] + +# Special-cased names that are implicitly exported from the stub (from m import y as y). +EXTRA_EXPORTED = { + 'pyasn1_modules.rfc2437.univ', + 'pyasn1_modules.rfc2459.char', + 'pyasn1_modules.rfc2459.univ', + 'elasticsearch.client.utils._make_path', +} + + class Options: """Represents stubgen options. This class is mutable to simplify testing. """ - def __init__(self, pyversion: Tuple[int, int], no_import: bool, doc_dir: str, - search_path: List[str], interpreter: str, parse_only: bool, ignore_errors: bool, - include_private: bool, output_dir: str, modules: List[str], packages: List[str], - files: List[str]) -> None: + def __init__(self, + pyversion: Tuple[int, int], + no_import: bool, + doc_dir: str, + search_path: List[str], + interpreter: str, + parse_only: bool, + ignore_errors: bool, + include_private: bool, + output_dir: str, + modules: List[str], + packages: List[str], + files: List[str], + verbose: bool, + quiet: bool) -> None: # See parse_options for descriptions of the flags. self.pyversion = pyversion self.no_import = no_import @@ -114,6 +149,8 @@ def __init__(self, pyversion: Tuple[int, int], no_import: bool, doc_dir: str, self.modules = modules self.packages = packages self.files = files + self.verbose = verbose + self.quiet = quiet class StubSource(BuildSource): @@ -323,7 +360,8 @@ def import_lines(self) -> List[str]: # We can already generate the import line if name in self.reverse_alias: name, alias = self.reverse_alias[name], name - result.append("import {} as {}\n".format(self.direct_imports[name], alias)) + source = self.direct_imports.get(name, 'FIXME') + result.append("import {} as {}\n".format(source, alias)) elif name in self.reexports: assert '.' not in name # Because reexports only has nonqualified names result.append("import {} as {}\n".format(name, name)) @@ -336,7 +374,78 @@ def import_lines(self) -> List[str]: return result +def find_defined_names(file: MypyFile) -> Set[str]: + finder = DefinitionFinder() + file.accept(finder) + return finder.names + + +class DefinitionFinder(mypy.traverser.TraverserVisitor): + """Find names of things defined at the top level of a module.""" + + # TODO: Assignment statements etc. + + def __init__(self) -> None: + # Short names of things defined at the top level. + self.names = set() # type: Set[str] + + def visit_class_def(self, o: ClassDef) -> None: + # Don't recurse into classes, as we only keep track of top-level definitions. + self.names.add(o.name) + + def visit_func_def(self, o: FuncDef) -> None: + # Don't recurse, as we only keep track of top-level definitions. + self.names.add(o.name) + + +def find_referenced_names(file: MypyFile) -> Set[str]: + finder = ReferenceFinder() + file.accept(finder) + return finder.refs + + +class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor): + """Find all name references (both local and global).""" + + # TODO: Filter out local variable and class attribute references + + def __init__(self) -> None: + # Short names of things defined at the top level. + self.refs = set() # type: Set[str] + + def visit_block(self, block: Block) -> None: + if not block.is_unreachable: + super().visit_block(block) + + def visit_name_expr(self, e: NameExpr) -> None: + self.refs.add(e.name) + + def visit_instance(self, t: Instance) -> None: + self.add_ref(t.type.fullname) + super().visit_instance(t) + + def visit_unbound_type(self, t: UnboundType) -> None: + if t.name: + self.add_ref(t.name) + + def visit_tuple_type(self, t: TupleType) -> None: + # Ignore fallback + for item in t.items: + item.accept(self) + + def visit_callable_type(self, t: CallableType) -> None: + # Ignore fallback + for arg in t.arg_types: + arg.accept(self) + t.ret_type.accept(self) + + def add_ref(self, fullname: str) -> None: + self.refs.add(fullname.split('.')[-1]) + + class StubGenerator(mypy.traverser.TraverserVisitor): + """Generate stub text from a mypy AST.""" + def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], include_private: bool = False, analyzed: bool = False) -> None: # Best known value of __all__. @@ -358,13 +467,22 @@ def __init__(self, _all_: Optional[List[str]], pyversion: Tuple[int, int], self.analyzed = analyzed # Add imports that could be implicitly generated self.import_tracker.add_import_from("collections", [("namedtuple", None)]) - typing_imports = "Any Optional TypeVar".split() - self.import_tracker.add_import_from("typing", [(t, None) for t in typing_imports]) # Names in __all__ are required for name in _all_ or (): self.import_tracker.reexport(name) + self.defined_names = set() # type: Set[str] def visit_mypy_file(self, o: MypyFile) -> None: + self.module = o.fullname + self.defined_names = find_defined_names(o) + self.referenced_names = find_referenced_names(o) + typing_imports = ["Any", "Optional", "TypeVar"] + for t in typing_imports: + if t not in self.defined_names: + alias = None + else: + alias = '_' + t + self.import_tracker.add_import_from("typing", [(t, alias)]) super().visit_mypy_file(o) undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] @@ -376,7 +494,7 @@ def visit_mypy_file(self, o: MypyFile) -> None: self.add('# %s\n' % name) def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None: - if self.is_private_name(o.name): + if self.is_private_name(o.name, o.fullname): return if self.is_not_in_all(o.name): return @@ -412,7 +530,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None: and not is_self_arg and not is_cls_arg): self.add_typing_import("Any") - annotation = ": Any" + annotation = ": {}".format(self.typing_name("Any")) elif annotated_type and not is_self_arg: annotation = ": {}".format(self.print_annotation(annotated_type)) else: @@ -440,7 +558,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None: retname = self.print_annotation(o.unanalyzed_type.ret_type) elif isinstance(o, FuncDef) and o.is_abstract: # Always assume abstract methods return Any unless explicitly annotated. - retname = 'Any' + retname = self.typing_name('Any') self.add_typing_import("Any") elif o.name == '__init__' or not has_return_statement(o) and not is_abstract: retname = 'None' @@ -453,56 +571,83 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False) -> None: self._state = FUNC def visit_decorator(self, o: Decorator) -> None: - if self.is_private_name(o.func.name): + if self.is_private_name(o.func.name, o.func.fullname): return is_abstract = False for decorator in o.original_decorators: if isinstance(decorator, NameExpr): - if decorator.name in ('property', - 'staticmethod', - 'classmethod'): - self.add_decorator('%s@%s\n' % (self._indent, decorator.name)) - elif self.import_tracker.module_for.get(decorator.name) in ('asyncio', - 'asyncio.coroutines', - 'types'): - self.add_coroutine_decorator(o.func, decorator.name, decorator.name) - elif (self.import_tracker.module_for.get(decorator.name) == 'abc' and - (decorator.name == 'abstractmethod' or - self.import_tracker.reverse_alias.get(decorator.name) == 'abstractmethod')): - self.add_decorator('%s@%s\n' % (self._indent, decorator.name)) - self.import_tracker.require_name(decorator.name) + if self.process_name_expr_decorator(decorator, o): is_abstract = True elif isinstance(decorator, MemberExpr): - if decorator.name == 'setter' and isinstance(decorator.expr, NameExpr): - self.add_decorator('%s@%s.setter\n' % (self._indent, decorator.expr.name)) - elif (isinstance(decorator.expr, NameExpr) and - (decorator.expr.name == 'abc' or - self.import_tracker.reverse_alias.get('abc')) and - decorator.name == 'abstractmethod'): - self.import_tracker.require_name(decorator.expr.name) - self.add_decorator('%s@%s.%s\n' % - (self._indent, decorator.expr.name, decorator.name)) + if self.process_member_expr_decorator(decorator, o): is_abstract = True - elif decorator.name == 'coroutine': - if (isinstance(decorator.expr, MemberExpr) and - decorator.expr.name == 'coroutines' and - isinstance(decorator.expr.expr, NameExpr) and - (decorator.expr.expr.name == 'asyncio' or - self.import_tracker.reverse_alias.get(decorator.expr.expr.name) == - 'asyncio')): - self.add_coroutine_decorator(o.func, - '%s.coroutines.coroutine' % - (decorator.expr.expr.name,), - decorator.expr.expr.name) - elif (isinstance(decorator.expr, NameExpr) and - (decorator.expr.name in ('asyncio', 'types') or - self.import_tracker.reverse_alias.get(decorator.expr.name) in - ('asyncio', 'asyncio.coroutines', 'types'))): - self.add_coroutine_decorator(o.func, - decorator.expr.name + '.coroutine', - decorator.expr.name) self.visit_func_def(o.func, is_abstract=is_abstract) + def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> bool: + """Process a function decorator of form @foo. + + Only preserve certain special decorators such as @abstractmethod. + + Return True if the decorator makes a method abstract. + """ + is_abstract = False + name = expr.name + if name in ('property', 'staticmethod', 'classmethod'): + self.add_decorator(name) + elif self.import_tracker.module_for.get(name) in ('asyncio', + 'asyncio.coroutines', + 'types'): + self.add_coroutine_decorator(context.func, name, name) + elif any(self.refers_to_fullname(name, target) + for target in ('abc.abstractmethod', 'abc.abstractproperty')): + self.add_decorator(name) + self.import_tracker.require_name(name) + is_abstract = True + return is_abstract + + def refers_to_fullname(self, name: str, fullname: str) -> bool: + module, short = fullname.rsplit('.', 1) + return (self.import_tracker.module_for.get(name) == module and + (name == short or + self.import_tracker.reverse_alias.get(name) == short)) + + def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> bool: + """Process a function decorator of form @foo.bar. + + Only preserve certain special decorators such as @abstractmethod. + + Return True if the decorator makes a method abstract. + """ + is_abstract = False + if expr.name == 'setter' and isinstance(expr.expr, NameExpr): + self.add_decorator('%s.setter' % expr.expr.name) + elif (isinstance(expr.expr, NameExpr) and + (expr.expr.name == 'abc' or + self.import_tracker.reverse_alias.get('abc')) and + expr.name in ('abstractmethod', 'abstractproperty')): + self.import_tracker.require_name(expr.expr.name) + self.add_decorator('%s.%s' % (expr.expr.name, expr.name)) + is_abstract = True + elif expr.name == 'coroutine': + if (isinstance(expr.expr, MemberExpr) and + expr.expr.name == 'coroutines' and + isinstance(expr.expr.expr, NameExpr) and + (expr.expr.expr.name == 'asyncio' or + self.import_tracker.reverse_alias.get(expr.expr.expr.name) == + 'asyncio')): + self.add_coroutine_decorator(context.func, + '%s.coroutines.coroutine' % + (expr.expr.expr.name,), + expr.expr.expr.name) + elif (isinstance(expr.expr, NameExpr) and + (expr.expr.name in ('asyncio', 'types') or + self.import_tracker.reverse_alias.get(expr.expr.name) in + ('asyncio', 'asyncio.coroutines', 'types'))): + self.add_coroutine_decorator(context.func, + expr.expr.name + '.coroutine', + expr.expr.name) + return is_abstract + def visit_class_def(self, o: ClassDef) -> None: sep = None # type: Optional[int] if not self._indent and self._state != EMPTY: @@ -554,6 +699,12 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: base_types.append(base.accept(p)) return base_types + def visit_block(self, o: Block) -> None: + # Unreachable statements may be partially uninitialized and that may + # cause trouble. + if not o.is_unreachable: + super().visit_block(o) + def visit_assignment_stmt(self, o: AssignmentStmt) -> None: foundl = [] @@ -682,9 +833,31 @@ def visit_import_all(self, o: ImportAll) -> None: def visit_import_from(self, o: ImportFrom) -> None: exported_names = set() # type: Set[str] - self.import_tracker.add_import_from('.' * o.relative + o.id, o.names) - self._vars[-1].extend(alias or name for name, alias in o.names) - for name, alias in o.names: + import_names = [] + module, relative = self.translate_module_name(o.id, o.relative) + if module == '__future__': + return # Not preserved + for name, as_name in o.names: + if name == 'six': + # Vendored six -- translate into plain 'import six'. + self.visit_import(Import([('six', None)])) + continue + exported = False + if as_name is None and self.module and (self.module + '.' + name) in EXTRA_EXPORTED: + exported = True + if (as_name is None and name not in self.referenced_names and not self._all_ + and module not in ('abc', 'typing', 'asyncio')): + # An imported name that is never referenced in the module is assumed to be + # exported, unless there is an explicit __all__. Note that we need to special + # case 'abc' since some references are deleted during semantic analysis. + exported = True + if exported: + self.import_tracker.reexport(name) + as_name = name + import_names.append((name, as_name)) + self.import_tracker.add_import_from('.' * relative + module, import_names) + self._vars[-1].extend(alias or name for name, alias in import_names) + for name, alias in import_names: self.record_name(alias or name) if self._all_: @@ -694,14 +867,22 @@ def visit_import_from(self, o: ImportFrom) -> None: exported_names.update(names) else: # Include import from targets that import from a submodule of a package. - if o.relative: + if relative: sub_names = [name for name, alias in o.names if alias is None] exported_names.update(sub_names) - if o.id: + if module: for name in sub_names: self.import_tracker.require_name(name) + def translate_module_name(self, module: str, relative: int) -> Tuple[str, int]: + for pkg in VENDOR_PACKAGES: + for alt in 'six', 'six.moves': + if (module.endswith('.{}.{}'.format(pkg, alt)) + or (module == '{}.{}'.format(pkg, alt) and relative)): + return alt, 0 + return module, relative + def visit_import(self, o: Import) -> None: for id, as_id in o.ids: self.import_tracker.add_import(id, as_id) @@ -744,17 +925,27 @@ def add(self, string: str) -> None: """Add text to generated stub.""" self._output.append(string) - def add_decorator(self, string: str) -> None: - self._decorators.append(string) + def add_decorator(self, name: str) -> None: + if not self._indent and self._state not in (EMPTY, FUNC): + self._decorators.append('\n') + self._decorators.append('%s@%s\n' % (self._indent, name)) def clear_decorators(self) -> None: self._decorators.clear() + def typing_name(self, name: str) -> str: + if name in self.defined_names: + # Avoid name clash between name from typing and a name defined in stub. + return '_' + name + else: + return name + def add_typing_import(self, name: str) -> None: """Add a name to be imported from typing, unless it's imported already. The import will be internal to the stub. """ + name = self.typing_name(name) self.import_tracker.require_name(name) def add_import_line(self, line: str) -> None: @@ -764,9 +955,7 @@ def add_import_line(self, line: str) -> None: def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) -> None: func.is_awaitable_coroutine = True - if not self._indent and self._state not in (EMPTY, FUNC): - self.add_decorator('\n') - self.add_decorator('%s@%s\n' % (self._indent, name)) + self.add_decorator(name) self.import_tracker.require_name(require_name) def output(self) -> str: @@ -786,9 +975,11 @@ def is_not_in_all(self, name: str) -> bool: return self.is_top_level() and name not in self._all_ return False - def is_private_name(self, name: str) -> bool: + def is_private_name(self, name: str, fullname: Optional[str] = None) -> bool: if self._include_private: return False + if fullname in EXTRA_EXPORTED: + return False return name.startswith('_') and (not name.endswith('__') or name in ('__all__', '__author__', @@ -824,9 +1015,10 @@ def get_str_type_of_node(self, rvalue: Expression, isinstance(rvalue, NameExpr) and rvalue.name == 'None': self.add_typing_import('Optional') self.add_typing_import('Any') - return 'Optional[Any]' + return '{}[{}]'.format(self.typing_name('Optional'), + self.typing_name('Any')) self.add_typing_import('Any') - return 'Any' + return self.typing_name('Any') def print_annotation(self, t: Type) -> str: printer = AnnotationPrinter(self) @@ -880,6 +1072,22 @@ def get_qualified_name(o: Expression) -> str: return ERROR_MARKER +def remove_blacklisted_modules(modules: List[StubSource]) -> List[StubSource]: + return [module for module in modules + if module.path is None or not is_blacklisted_path(module.path)] + + +def is_blacklisted_path(path: str) -> bool: + return any(substr in (normalize_path_separators(path) + '\n') + for substr in BLACKLIST) + + +def normalize_path_separators(path: str) -> str: + if sys.platform == 'win32': + return path.replace('\\', '/') + return path + + def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[List[StubSource], List[StubSource]]: """Collect files for which we need to generate stubs. @@ -898,7 +1106,9 @@ def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[Lis py_modules, c_modules = find_module_paths_using_imports(options.modules, options.packages, options.interpreter, - options.pyversion) + options.pyversion, + options.verbose, + options.quiet) else: # Use mypy native source collection for files and directories. try: @@ -908,31 +1118,39 @@ def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[Lis py_modules = [StubSource(m.module, m.path) for m in source_list] c_modules = [] + py_modules = remove_blacklisted_modules(py_modules) + return py_modules, c_modules -def find_module_paths_using_imports(modules: List[str], packages: List[str], +def find_module_paths_using_imports(modules: List[str], + packages: List[str], interpreter: str, pyversion: Tuple[int, int], - quiet: bool = True) -> Tuple[List[StubSource], - List[StubSource]]: + verbose: bool, + quiet: bool) -> Tuple[List[StubSource], + List[StubSource]]: """Find path and runtime value of __all__ (if possible) for modules and packages. This function uses runtime Python imports to get the information. """ py_modules = [] # type: List[StubSource] c_modules = [] # type: List[StubSource] - modules = modules + list(walk_packages(packages)) + found = list(walk_packages(packages, verbose)) + modules = modules + found + modules = [mod for mod in modules if not is_test_module(mod)] # We don't want to run any tests for mod in modules: try: if pyversion[0] == 2: result = find_module_path_and_all_py2(mod, interpreter) else: - result = find_module_path_and_all_py3(mod) + result = find_module_path_and_all_py3(mod, verbose) except CantImport as e: + tb = traceback.format_exc() + if verbose: + sys.stdout.write(tb) if not quiet: - traceback.print_exc() - report_missing(mod, e.message) + report_missing(mod, e.message, tb) continue if not result: c_modules.append(StubSource(mod)) @@ -942,6 +1160,26 @@ def find_module_paths_using_imports(modules: List[str], packages: List[str], return py_modules, c_modules +def is_test_module(module: str) -> bool: + """Does module look like a test module?""" + if module.endswith(( + '.tests', + '.test', + '.testing', + '_tests', + '.conftest', + 'test_util', + 'test_utils', + 'test_base', + )): + return True + if module.split('.')[-1].startswith('test_'): + return True + if '.tests.' in module or '.test.' in module or '.testing.' in module: + return True + return False + + def find_module_paths_using_search(modules: List[str], packages: List[str], search_path: List[str], pyversion: Tuple[int, int]) -> List[StubSource]: @@ -966,6 +1204,9 @@ def find_module_paths_using_search(modules: List[str], packages: List[str], fail_missing(package) sources = [StubSource(m.module, m.path) for m in p_result] result.extend(sources) + + result = [m for m in result if not is_test_module(m.module)] + return result @@ -977,6 +1218,8 @@ def mypy_options(stubgen_options: Options) -> MypyOptions: options.ignore_errors = True options.semantic_analysis_only = True options.python_version = stubgen_options.pyversion + options.show_traceback = True + options.transform_source = remove_misplaced_type_comments return options @@ -1003,6 +1246,8 @@ def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None: def generate_asts_for_modules(py_modules: List[StubSource], parse_only: bool, mypy_options: MypyOptions) -> None: """Use mypy to parse (and optionally analyze) source files.""" + if not py_modules: + return # Nothing to do here, but there may be C modules if parse_only: for mod in py_modules: parse_source_file(mod, mypy_options) @@ -1024,8 +1269,7 @@ def generate_stub_from_ast(mod: StubSource, target: str, parse_only: bool = False, pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION, - include_private: bool = False, - add_header: bool = True) -> None: + include_private: bool = False) -> None: """Use analysed (or just parsed) AST to generate type stub for single file. If directory for target doesn't exist it will created. Existing stub @@ -1043,8 +1287,6 @@ def generate_stub_from_ast(mod: StubSource, if subdir and not os.path.isdir(subdir): os.makedirs(subdir) with open(target, 'w') as file: - if add_header: - write_header(file, mod.module, pyversion=pyversion) file.write(''.join(gen.output())) @@ -1066,9 +1308,7 @@ def collect_docs_signatures(doc_dir: str) -> Tuple[Dict[str, str], Dict[str, str return sigs, class_sigs -def generate_stubs(options: Options, - # additional args for testing - quiet: bool = False, add_header: bool = True) -> None: +def generate_stubs(options: Options) -> None: """Main entry point for the program.""" mypy_opts = mypy_options(options) py_modules, c_modules = collect_build_targets(options, mypy_opts) @@ -1080,6 +1320,7 @@ def generate_stubs(options: Options, # Use parsed sources to generate stubs for Python modules. generate_asts_for_modules(py_modules, options.parse_only, mypy_opts) + files = [] for mod in py_modules: assert mod.path is not None, "Not found module was not skipped" target = mod.module.replace('.', '/') @@ -1088,18 +1329,30 @@ def generate_stubs(options: Options, else: target += '.pyi' target = os.path.join(options.output_dir, target) - with generate_guarded(mod.module, target, options.ignore_errors, quiet): + files.append(target) + with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): generate_stub_from_ast(mod, target, options.parse_only, options.pyversion, - options.include_private, add_header) + options.include_private) # Separately analyse C modules using different logic. for mod in c_modules: - target = mod.module.replace('.', '/') + '.pyi' + if any(py_mod.module.startswith(mod.module + '.') + for py_mod in py_modules + c_modules): + target = mod.module.replace('.', '/') + '/__init__.pyi' + else: + target = mod.module.replace('.', '/') + '.pyi' target = os.path.join(options.output_dir, target) - with generate_guarded(mod.module, target, options.ignore_errors, quiet): - generate_stub_for_c_module(mod.module, target, sigs=sigs, class_sigs=class_sigs, - add_header=add_header) + files.append(target) + with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): + generate_stub_for_c_module(mod.module, target, sigs=sigs, class_sigs=class_sigs) + num_modules = len(py_modules) + len(c_modules) + if not options.quiet and num_modules > 0: + print('Processed %d modules' % num_modules) + if len(files) == 1: + print('Generated %s' % files[0]) + else: + print('Generated files under %s' % common_dir_prefix(files) + os.sep) HEADER = """%(prog)s [-h] [--py2] [more options, see -h] @@ -1132,6 +1385,10 @@ def parse_options(args: List[str]) -> Options: parser.add_argument('--include-private', action='store_true', help="generate stubs for objects and members considered private " "(single leading underscore and no trailing underscores)") + parser.add_argument('-v', '--verbose', action='store_true', + help="show more verbose messages") + parser.add_argument('-q', '--quiet', action='store_true', + help="show fewer messages") parser.add_argument('--doc-dir', metavar='PATH', default='', help="use .rst documentation in PATH (this may result in " "better stubs in some cases; consider setting this to " @@ -1160,6 +1417,8 @@ def parse_options(args: List[str]) -> Options: ns.interpreter = sys.executable if pyversion[0] == 3 else default_py2_interpreter() if ns.modules + ns.packages and ns.files: parser.error("May only specify one of: modules/packages or files.") + if ns.quiet and ns.verbose: + parser.error('Cannot specify both quiet and verbose messages') # Create the output folder if it doesn't already exist. if not os.path.exists(ns.output_dir): @@ -1176,7 +1435,9 @@ def parse_options(args: List[str]) -> Options: output_dir=ns.output_dir, modules=ns.modules, packages=ns.packages, - files=ns.files) + files=ns.files, + verbose=ns.verbose, + quiet=ns.quiet) def main() -> None: diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index cbe1575379dd..7de5f450aa6d 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -11,7 +11,7 @@ from typing import List, Dict, Tuple, Optional, Mapping, Any, Set from types import ModuleType -from mypy.stubutil import write_header, is_c_module +from mypy.stubutil import is_c_module from mypy.stubdoc import ( infer_sig_from_docstring, infer_prop_type_from_docstring, ArgSig, infer_arg_sig_from_docstring, FunctionSig @@ -20,7 +20,6 @@ def generate_stub_for_c_module(module_name: str, target: str, - add_header: bool = True, sigs: Optional[Dict[str, str]] = None, class_sigs: Optional[Dict[str, str]] = None) -> None: """Generate stub for C module. @@ -76,8 +75,6 @@ def generate_stub_for_c_module(module_name: str, output.append(line) output = add_typing_import(output) with open(target, 'w') as file: - if add_header: - write_header(file, module_name) for line in output: file.write('%s\n' % line) @@ -339,33 +336,52 @@ def is_skipped_attribute(attr: str) -> bool: def infer_method_sig(name: str) -> List[ArgSig]: + args = None # type: Optional[List[ArgSig]] if name.startswith('__') and name.endswith('__'): name = name[2:-2] if name in ('hash', 'iter', 'next', 'sizeof', 'copy', 'deepcopy', 'reduce', 'getinitargs', - 'int', 'float', 'trunc', 'complex', 'bool'): - return [] - if name == 'getitem': - return [ArgSig(name='index')] - if name == 'setitem': - return [ArgSig(name='index'), + 'int', 'float', 'trunc', 'complex', 'bool', 'abs', 'bytes', 'dir', 'len', + 'reversed', 'round', 'index', 'enter'): + args = [] + elif name == 'getitem': + args = [ArgSig(name='index')] + elif name == 'setitem': + args = [ArgSig(name='index'), ArgSig(name='object')] - if name in ('delattr', 'getattr'): - return [ArgSig(name='name')] - if name == 'setattr': - return [ArgSig(name='name'), + elif name in ('delattr', 'getattr'): + args = [ArgSig(name='name')] + elif name == 'setattr': + args = [ArgSig(name='name'), ArgSig(name='value')] - if name == 'getstate': - return [] - if name == 'setstate': - return [ArgSig(name='state')] - if name in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', - 'add', 'radd', 'sub', 'rsub', 'mul', 'rmul', - 'mod', 'rmod', 'floordiv', 'rfloordiv', 'truediv', 'rtruediv', - 'divmod', 'rdivmod', 'pow', 'rpow'): - return [ArgSig(name='other')] - if name in ('neg', 'pos'): - return [] - return [ - ArgSig(name='*args'), - ArgSig(name='**kwargs') - ] + elif name == 'getstate': + args = [] + elif name == 'setstate': + args = [ArgSig(name='state')] + elif name in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', + 'add', 'radd', 'sub', 'rsub', 'mul', 'rmul', + 'mod', 'rmod', 'floordiv', 'rfloordiv', 'truediv', 'rtruediv', + 'divmod', 'rdivmod', 'pow', 'rpow', + 'xor', 'rxor', 'or', 'ror', 'and', 'rand', 'lshift', 'rlshift', + 'rshift', 'rrshift', + 'contains', 'delitem', + 'iadd', 'iand', 'ifloordiv', 'ilshift', 'imod', 'imul', 'ior', + 'ipow', 'irshift', 'isub', 'itruediv', 'ixor'): + args = [ArgSig(name='other')] + elif name in ('neg', 'pos', 'invert'): + args = [] + elif name == 'get': + args = [ArgSig(name='instance'), + ArgSig(name='owner')] + elif name == 'set': + args = [ArgSig(name='instance'), + ArgSig(name='value')] + elif name == 'reduce_ex': + args = [ArgSig(name='protocol')] + elif name == 'exit': + args = [ArgSig(name='type'), + ArgSig(name='value'), + ArgSig(name='traceback')] + if args is None: + args = [ArgSig(name='*args'), + ArgSig(name='**kwargs')] + return [ArgSig(name='self')] + args diff --git a/mypy/stubutil.py b/mypy/stubutil.py index 0f2314bc8504..ad3c63d8f10a 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -7,10 +7,18 @@ import pkgutil import importlib import subprocess +import re from types import ModuleType from contextlib import contextmanager -from typing import Optional, Tuple, List, IO, Iterator +from typing import Optional, Tuple, List, Iterator, Union +from typing_extensions import overload + + +# Modules that may fail when imported, or that may have side effects. +NOT_IMPORTABLE_MODULES = { + 'tensorflow.tools.pip_package.setup', +} class CantImport(Exception): @@ -20,18 +28,11 @@ def __init__(self, module: str, message: str): def is_c_module(module: ModuleType) -> bool: - return ('__file__' not in module.__dict__ or - os.path.splitext(module.__dict__['__file__'])[-1] in ['.so', '.pyd']) - - -def write_header(file: IO[str], module_name: Optional[str] = None, - pyversion: Tuple[int, int] = (3, 5)) -> None: - """Write a header to file indicating this file is auto-generated by stubgen.""" - if module_name: - file.write('# Stubs for %s (Python %s)\n' % (module_name, pyversion[0])) - file.write( - '#\n' - '# NOTE: This dynamically typed stub was automatically generated by stubgen.\n\n') + if module.__dict__.get('__file__') is None: + # Could be a namespace package. These must be handled through + # introspection, since there is no source file. + return True + return os.path.splitext(module.__dict__['__file__'])[-1] in ['.so', '.pyd'] def default_py2_interpreter() -> str: @@ -51,7 +52,7 @@ def default_py2_interpreter() -> str: "please use the --python-executable option") -def walk_packages(packages: List[str]) -> Iterator[str]: +def walk_packages(packages: List[str], verbose: bool = False) -> Iterator[str]: """Iterates through all packages and sub-packages in the given list. This uses runtime imports to find both Python and C modules. For Python packages @@ -61,6 +62,11 @@ def walk_packages(packages: List[str]) -> Iterator[str]: all modules imported in the package that have matching names. """ for package_name in packages: + if package_name in NOT_IMPORTABLE_MODULES: + print('%s: Skipped (blacklisted)' % package_name) + continue + if verbose: + print('Trying to import %r for runtime introspection' % package_name) try: package = importlib.import_module(package_name) except Exception: @@ -80,7 +86,7 @@ def walk_packages(packages: List[str]) -> Iterator[str]: if inspect.ismodule(val) and val.__name__ == package.__name__ + "." + name] # Recursively iterate through the subpackages - for submodule in walk_packages(subpackages): + for submodule in walk_packages(subpackages, verbose): yield submodule # It's a module inside a package. There's nothing else to walk/yield. else: @@ -118,13 +124,19 @@ def find_module_path_and_all_py2(module: str, return module_path, module_all -def find_module_path_and_all_py3(module: str) -> Optional[Tuple[str, Optional[List[str]]]]: +def find_module_path_and_all_py3(module: str, + verbose: bool) -> Optional[Tuple[str, Optional[List[str]]]]: """Find module and determine __all__ for a Python 3 module. Return None if the module is a C module. Return (module_path, __all__) if it is a Python module. Raise CantImport if import failed. """ + if module in NOT_IMPORTABLE_MODULES: + raise CantImport(module, '') + # TODO: Support custom interpreters. + if verbose: + print('Trying to import %r for runtime introspection' % module) try: mod = importlib.import_module(module) except Exception as e: @@ -139,11 +151,13 @@ def find_module_path_and_all_py3(module: str) -> Optional[Tuple[str, Optional[Li @contextmanager def generate_guarded(mod: str, target: str, - ignore_errors: bool = True, quiet: bool = False) -> Iterator[None]: + ignore_errors: bool = True, verbose: bool = False) -> Iterator[None]: """Ignore or report errors during stub generation. Optionally report success. """ + if verbose: + print('Processing %s' % mod) try: yield except Exception as e: @@ -153,15 +167,74 @@ def generate_guarded(mod: str, target: str, # --ignore-errors was passed print("Stub generation failed for", mod, file=sys.stderr) else: - if not quiet: + if verbose: print('Created %s' % target) -def report_missing(mod: str, message: Optional[str] = '') -> None: +PY2_MODULES = {'cStringIO', 'urlparse', 'collections.UserDict'} + + +def report_missing(mod: str, message: Optional[str] = '', traceback: str = '') -> None: if message: message = ' with error: ' + message - print('Failed to import {}{}; skipping it'.format(mod, message)) + print('{}: Failed to import, skipping{}'.format(mod, message)) + m = re.search(r"ModuleNotFoundError: No module named '([^']*)'", traceback) + if m: + missing_module = m.group(1) + if missing_module in PY2_MODULES: + print('note: Try --py2 for Python 2 mode') def fail_missing(mod: str) -> None: raise SystemExit("Can't find module '{}' (consider using --search-path)".format(mod)) + + +@overload +def remove_misplaced_type_comments(source: bytes) -> bytes: ... + + +@overload +def remove_misplaced_type_comments(source: str) -> str: ... + + +def remove_misplaced_type_comments(source: Union[str, bytes]) -> Union[str, bytes]: + """Remove comments from source that could be understood as misplaced type comments. + + Normal comments may look like misplaced type comments, and since they cause blocking + parse errors, we want to avoid them. + """ + if isinstance(source, bytes): + # This gives us a 1-1 character code mapping, so it's roundtrippable. + text = source.decode('latin1') + else: + text = source + + # Remove something that looks like a variable type comment but that's by itself + # on a line, as it will often generate a parse error (unless it's # type: ignore). + text = re.sub(r'^[ \t]*# +type: +["\'a-zA-Z_].*$', '', text, flags=re.MULTILINE) + + # Remove something that looks like a function type comment after docstring, + # which will result in a parse error. + text = re.sub(r'""" *\n[ \t\n]*# +type: +\(.*$', '"""\n', text, flags=re.MULTILINE) + text = re.sub(r"''' *\n[ \t\n]*# +type: +\(.*$", "'''\n", text, flags=re.MULTILINE) + + # Remove something that looks like a badly formed function type comment. + text = re.sub(r'^[ \t]*# +type: +\([^()]+(\)[ \t]*)?$', '', text, flags=re.MULTILINE) + + if isinstance(source, bytes): + return text.encode('latin1') + else: + return text + + +def common_dir_prefix(paths: List[str]) -> str: + if not paths: + return '.' + cur = os.path.dirname(paths[0]) + for path in paths[1:]: + while True: + path = os.path.dirname(path) + if (cur + '/').startswith(path + '/'): + cur = path + break + return cur or '.' diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 438bfdeebe53..4d3eaef112c5 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -15,9 +15,9 @@ from mypy.errors import CompileError from mypy.stubgen import ( generate_stubs, parse_options, Options, collect_build_targets, - mypy_options + mypy_options, is_blacklisted_path, is_test_module ) -from mypy.stubutil import walk_packages +from mypy.stubutil import walk_packages, remove_misplaced_type_comments, common_dir_prefix from mypy.stubgenc import generate_c_type_stub, infer_method_sig, generate_c_function_stub from mypy.stubdoc import ( parse_signature, parse_all_signatures, build_signature, find_unique_signatures, @@ -27,6 +27,8 @@ class StubgenCmdLineSuite(Suite): + """Test cases for processing command-line options and finding files.""" + def test_files_found(self) -> None: current = os.getcwd() with tempfile.TemporaryDirectory() as tmp: @@ -112,6 +114,8 @@ def test_walk_packages(self) -> None: class StubgenUtilSuite(Suite): + """Unit tests for stubgen utility functions.""" + def test_parse_signature(self) -> None: self.assert_parse_signature('func()', ('func', [], [])) @@ -253,6 +257,13 @@ def test_infer_sig_from_docstring_duplicate_args(self) -> None: [FunctionSig(name='func', args=[ArgSig(name='x'), ArgSig(name='y')], ret_type='int')]) + def test_infer_sig_from_docstring_bad_indentation(self) -> None: + assert_equal(infer_sig_from_docstring(""" + x + x + x + """, 'func'), None) + def test_infer_arg_sig_from_docstring(self) -> None: assert_equal(infer_arg_sig_from_docstring("(*args, **kwargs)"), [ArgSig(name='*args'), ArgSig(name='**kwargs')]) @@ -270,8 +281,228 @@ def test_infer_prop_type_from_docstring(self) -> None: 'Tuple[int, int]') assert_equal(infer_prop_type_from_docstring('\nstr: A string.'), None) + def test_remove_misplaced_type_comments_1(self) -> None: + good = """ + \u1234 + def f(x): # type: (int) -> int + + def g(x): + # type: (int) -> int + + def h(): + + # type: () int + + x = 1 # type: int + """ + + assert_equal(remove_misplaced_type_comments(good), good) + + def test_remove_misplaced_type_comments_2(self) -> None: + bad = """ + def f(x): + # type: Callable[[int], int] + pass + + # type: "foo" + # type: 'bar' + x = 1 + # type: int + """ + bad_fixed = """ + def f(x): + + pass + + + + x = 1 + + """ + assert_equal(remove_misplaced_type_comments(bad), bad_fixed) + + def test_remove_misplaced_type_comments_3(self) -> None: + bad = ''' + def f(x): + """docstring""" + # type: (int) -> int + pass + + def g(x): + """docstring + """ + # type: (int) -> int + pass + ''' + bad_fixed = ''' + def f(x): + """docstring""" + + pass + + def g(x): + """docstring + """ + + pass + ''' + assert_equal(remove_misplaced_type_comments(bad), bad_fixed) + + def test_remove_misplaced_type_comments_4(self) -> None: + bad = """ + def f(x): + '''docstring''' + # type: (int) -> int + pass + + def g(x): + '''docstring + ''' + # type: (int) -> int + pass + """ + bad_fixed = """ + def f(x): + '''docstring''' + + pass + + def g(x): + '''docstring + ''' + + pass + """ + assert_equal(remove_misplaced_type_comments(bad), bad_fixed) + + def test_remove_misplaced_type_comments_5(self) -> None: + bad = """ + def f(x): + # type: (int, List[Any], + # float, bool) -> int + pass + + def g(x): + # type: (int, List[Any]) + pass + """ + bad_fixed = """ + def f(x): + + # float, bool) -> int + pass + + def g(x): + + pass + """ + assert_equal(remove_misplaced_type_comments(bad), bad_fixed) + + def test_remove_misplaced_type_comments_bytes(self) -> None: + original = b""" + \xbf + def f(x): # type: (int) -> int + + def g(x): + # type: (int) -> int + pass + + def h(): + # type: int + pass + + x = 1 # type: int + """ + + dest = b""" + \xbf + def f(x): # type: (int) -> int + + def g(x): + # type: (int) -> int + pass + + def h(): + + pass + + x = 1 # type: int + """ + + assert_equal(remove_misplaced_type_comments(original), dest) + + def test_common_dir_prefix(self) -> None: + assert common_dir_prefix([]) == '.' + assert common_dir_prefix(['x.pyi']) == '.' + assert common_dir_prefix(['./x.pyi']) == '.' + assert common_dir_prefix(['foo/bar/x.pyi']) == 'foo/bar' + assert common_dir_prefix(['foo/bar/x.pyi', + 'foo/bar/y.pyi']) == 'foo/bar' + assert common_dir_prefix(['foo/bar/x.pyi', 'foo/y.pyi']) == 'foo' + assert common_dir_prefix(['foo/x.pyi', 'foo/bar/y.pyi']) == 'foo' + assert common_dir_prefix(['foo/bar/zar/x.pyi', 'foo/y.pyi']) == 'foo' + assert common_dir_prefix(['foo/x.pyi', 'foo/bar/zar/y.pyi']) == 'foo' + assert common_dir_prefix(['foo/bar/zar/x.pyi', 'foo/bar/y.pyi']) == 'foo/bar' + assert common_dir_prefix(['foo/bar/x.pyi', 'foo/bar/zar/y.pyi']) == 'foo/bar' + + +class StubgenHelpersSuite(Suite): + def test_is_blacklisted_path(self) -> None: + assert not is_blacklisted_path('foo/bar.py') + assert not is_blacklisted_path('foo.py') + assert not is_blacklisted_path('foo/xvendor/bar.py') + assert not is_blacklisted_path('foo/vendorx/bar.py') + assert is_blacklisted_path('foo/vendor/bar.py') + assert is_blacklisted_path('foo/vendored/bar.py') + assert is_blacklisted_path('foo/vendored/bar/thing.py') + assert is_blacklisted_path('foo/six.py') + + def test_is_test_module(self) -> None: + assert not is_test_module('foo') + assert not is_test_module('foo.bar') + + # The following could be test modules, but we are very conservative and + # don't treat them as such since they could plausibly be real modules. + assert not is_test_module('foo.bartest') + assert not is_test_module('foo.bartests') + assert not is_test_module('foo.testbar') + + assert is_test_module('foo.test') + assert is_test_module('foo.test.foo') + assert is_test_module('foo.tests') + assert is_test_module('foo.tests.foo') + assert is_test_module('foo.testing.foo') + + assert is_test_module('foo.test_bar') + assert is_test_module('foo.bar_tests') + assert is_test_module('foo.testing') + assert is_test_module('foo.conftest') + assert is_test_module('foo.bar_test_util') + assert is_test_module('foo.bar_test_utils') + assert is_test_module('foo.bar_test_base') + class StubgenPythonSuite(DataSuite): + """Data-driven end-to-end test cases that generate stub files. + + You can use these magic test case name suffixes: + + *_semanal + Run semantic analysis (slow as this uses real stubs -- only use + when necessary) + *_import + Import module and perform runtime introspection (in the current + process!) + + You can use these magic comments: + + # flags: --some-stubgen-option ... + Specify custom stubgen options + + # modules: module1 module2 ... + Specify which modules to output (by default only 'main') + """ + required_out_section = True base_path = '.' files = ['stubgen.test'] @@ -281,17 +512,21 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: self.run_case_inner(testcase) def run_case_inner(self, testcase: DataDrivenTestCase) -> None: - extra = [] - mods = [] + extra = [] # Extra command-line args + mods = [] # Module names to process source = '\n'.join(testcase.input) for file, content in testcase.files + [('./main.py', source)]: - mod = os.path.basename(file)[:-3] + # Strip ./ prefix and .py suffix. + mod = file[2:-3].replace('/', '.') + if mod.endswith('.__init__'): + mod, _, _ = mod.rpartition('.') mods.append(mod) extra.extend(['-m', mod]) with open(file, 'w') as f: f.write(content) options = self.parse_flags(source, extra) + modules = self.parse_modules(source) out_dir = 'out' try: try: @@ -299,9 +534,11 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: options.no_import = True if not testcase.name.endswith('_semanal'): options.parse_only = True - generate_stubs(options, quiet=True, add_header=False) + generate_stubs(options) a = [] # type: List[str] - self.add_file(os.path.join(out_dir, 'main.pyi'), a) + for module in modules: + fnam = module_to_path(out_dir, module) + self.add_file(fnam, a, header=len(modules) > 1) except CompileError as e: a = e.messages assert_string_arrays_equal(testcase.output, a, @@ -319,32 +556,57 @@ def parse_flags(self, program_text: str, extra: List[str]) -> Options: flag_list = flags.group(1).split() else: flag_list = [] - return parse_options(flag_list + extra) + options = parse_options(flag_list + extra) + if '--verbose' not in flag_list: + options.quiet = True + else: + options.verbose = True + return options - def add_file(self, path: str, result: List[str]) -> None: + def parse_modules(self, program_text: str) -> List[str]: + modules = re.search('# modules: (.*)$', program_text, flags=re.MULTILINE) + if modules: + return modules.group(1).split() + else: + return ['main'] + + def add_file(self, path: str, result: List[str], header: bool) -> None: + if not os.path.exists(path): + result.append('<%s was not generated>' % path.replace('\\', '/')) + return + if header: + result.append('# {}'.format(path[4:])) with open(path, encoding='utf8') as file: result.extend(file.read().splitlines()) +self_arg = ArgSig(name='self') + + class StubgencSuite(Suite): + """Unit tests for stub generation from C modules using introspection. + + Note that these don't cover a lot! + """ + def test_infer_hash_sig(self) -> None: - assert_equal(infer_method_sig('__hash__'), []) + assert_equal(infer_method_sig('__hash__'), [self_arg]) def test_infer_getitem_sig(self) -> None: - assert_equal(infer_method_sig('__getitem__'), [ArgSig(name='index')]) + assert_equal(infer_method_sig('__getitem__'), [self_arg, ArgSig(name='index')]) def test_infer_setitem_sig(self) -> None: assert_equal(infer_method_sig('__setitem__'), - [ArgSig(name='index'), ArgSig(name='object')]) + [self_arg, ArgSig(name='index'), ArgSig(name='object')]) def test_infer_binary_op_sig(self) -> None: for op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', 'add', 'radd', 'sub', 'rsub', 'mul', 'rmul'): - assert_equal(infer_method_sig('__%s__' % op), [ArgSig(name='other')]) + assert_equal(infer_method_sig('__%s__' % op), [self_arg, ArgSig(name='other')]) def test_infer_unary_op_sig(self) -> None: for op in ('neg', 'pos'): - assert_equal(infer_method_sig('__%s__' % op), []) + assert_equal(infer_method_sig('__%s__' % op), [self_arg]) def test_generate_c_type_stub_no_crash_for_object(self) -> None: output = [] # type: List[str] @@ -535,3 +797,12 @@ def test_repr(self) -> None: "ArgSig(name='func', type='str', default=False)") assert_equal(repr(ArgSig("func", 'str', default=True)), "ArgSig(name='func', type='str', default=True)") + + +def module_to_path(out_dir: str, module: str) -> str: + fnam = os.path.join(out_dir, '{}.pyi'.format(module.replace('.', '/'))) + if not os.path.exists(fnam): + alt_fnam = fnam.replace('.pyi', '/__init__.pyi') + if os.path.exists(alt_fnam): + return alt_fnam + return fnam diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 9e134d95cb97..5d044524a05f 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -3715,7 +3715,7 @@ class A: pass class B(A): pass class C(B): pass class D(A, B): pass # E: Cannot determine consistent method resolution order (MRO) for "D" -class E(C, D): pass # E: Cannot determine consistent method resolution order (MRO) for "E" +class E(C, D): pass [case testInconsistentMroLocalRef] class A: pass diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index b95e34f3f0d4..18837e26dc79 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -1,3 +1,5 @@ +-- Test cases for stubgen that generate stubs from Python code + [case testEmptyFile] [out] @@ -580,8 +582,9 @@ from collections import namedtuple X = namedtuple('X', ['a', 'b']) [case testNamedtupleAltSyntax] -from collections import namedtuple, x +from collections import namedtuple, xx X = namedtuple('X', 'a b') +xx [out] from collections import namedtuple @@ -613,10 +616,11 @@ _X = namedtuple('_X', ['a', 'b']) class Y(_X): ... [case testNamedtupleAltSyntaxFieldsTuples] -from collections import namedtuple, x +from collections import namedtuple, xx X = namedtuple('X', ()) Y = namedtuple('Y', ('a',)) Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e')) +xx [out] from collections import namedtuple @@ -701,15 +705,16 @@ class A: [case testExportViaRelativeImport] from .api import get [out] -from .api import get +from .api import get as get [case testExportViaRelativePackageImport] from .packages.urllib3.contrib import parse [out] -from .packages.urllib3.contrib import parse +from .packages.urllib3.contrib import parse as parse [case testNoExportViaRelativeImport] from . import get +get() [out] [case testRelativeImportAndBase] @@ -754,35 +759,35 @@ class A: [case testAnnotationImportsFrom] import foo -from collection import defaultdict +from collections import defaultdict x: defaultdict [out] -from collection import defaultdict +from collections import defaultdict x: defaultdict [case testAnnotationImports] import foo -import collection -x: collection.defaultdict +import collections +x: collections.defaultdict [out] -import collection +import collections -x: collection.defaultdict +x: collections.defaultdict [case testAnnotationImports] from typing import List -import collection -x: List[collection.defaultdict] +import collections +x: List[collections.defaultdict] [out] -import collection +import collections from typing import List -x: List[collection.defaultdict] +x: List[collections.defaultdict] [case testAnnotationFwRefs] @@ -1427,7 +1432,7 @@ class A(metaclass=abc.ABCMeta): def meth(self): ... [case testABCMeta_semanal] -from base import base +from base import Base from abc import abstractmethod class C(Base): @@ -1445,6 +1450,7 @@ class Base(metaclass=ABCMeta): [out] import abc from abc import abstractmethod +from base import Base from typing import Any class C(Base, metaclass=abc.ABCMeta): @@ -1511,3 +1517,249 @@ class B: y: str = ... @x.setter def x(self, value: Any) -> None: ... + +[case testMisplacedTypeComment] +def f(): + x = 0 + + # type: str + y = '' + +[out] +def f() -> None: ... + +[case testConditionalImportAll_semanal] +__all__ = ['cookielib'] + +if object(): + from http import cookiejar as cookielib +else: + import cookielib + +[out] +import FIXME as cookielib + +[case testCannotCalculateMRO_semanal] +class X: pass + +class int(int, X): # Cycle + pass + +class A: pass +class B(A): pass +class C(B): pass +class D(A, B): pass # No consistent method resolution order +class E(C, D): pass # Ditto + +[out] +class X: ... +class int(int, X): ... +class A: ... +class B(A): ... +class C(B): ... +class D(A, B): ... +class E(C, D): ... + +[case testUnreachableCode_semanal] +MYPY = False +class A: pass +if MYPY: + class C(A): + def f(self) -> None: pass +else: + def f(i): + return i + + class C(A): + def g(self) -> None: pass +[out] +MYPY: bool + +class A: ... + +class C(A): + def f(self) -> None: ... + +[case testAbstractProperty1_semanal] +import other +import abc + +class A: + @abc.abstractproperty + def x(self): pass + +[out] +import abc +from typing import Any + +class A(metaclass=abc.ABCMeta): + @abc.abstractproperty + def x(self) -> Any: ... + +[case testAbstractProperty2_semanal] +import other +from abc import abstractproperty + +class A: + @abstractproperty + def x(self): pass + +[out] +import abc +from abc import abstractproperty +from typing import Any + +class A(metaclass=abc.ABCMeta): + @abstractproperty + def x(self) -> Any: ... + +[case testClassWithNameAnyOrOptional] +def f(x=object()): + return 1 + +def g(x=None): pass + +x = g() + +class Any: + pass + +def Optional(): + return 0 + +[out] +from typing import Any as _Any, Optional as _Optional + +def f(x: _Any = ...): ... +def g(x: _Optional[_Any] = ...) -> None: ... + +x: _Any + +class Any: ... + +def Optional(): ... + +[case testExportedNameImported] +# modules: main a b +from a import C + +class D(C): pass + +[file a.py] +from b import C + +[file b.py] +class C: pass + +[out] +# main.pyi +from a import C + +class D(C): ... +# a.pyi +from b import C as C +# b.pyi +class C: ... + +[case testVendoredSix] +from p1.vendored import six +from p1.vendor.six import foobar +from p1.packages.six.moves import http_client +from .packages.six.moves import queue + +class C(http_client.HTTPMessage): pass +class D(six.Iterator): pass + +[out] +import six +from six import foobar as foobar +from six.moves import http_client, queue as queue + +class C(http_client.HTTPMessage): ... +class D(six.Iterator): ... + +[case testVendoredPackage] +# modules: main p.vendored.requests p.sub.requests +from p.vendored.requests import Request +from p.sub.requests import Request2 + +x = Request() +y = Request2() + +[file p/__init__.py] + +[file p/vendored/__init__.py] + +[file p/vendored/requests.py] +class Request: + pass + +[file p/sub/__init__.py] + +[file p/sub/requests.py] +class Request2: + pass + +[out] +# main.pyi +from typing import Any + +x: Any +y: Any + +# p/sub/requests.pyi +class Request2: ... + +[case testTestFiles] +# modules: p p.x p.tests p.tests.test_foo + +[file p/__init__.py] +def f(): pass + +[file p/x.py] +def g(): pass + +[file p/tests/__init__.py] + +[file p/tests/test_foo.py] +def test_thing(): pass + +[out] +# p/__init__.pyi +def f() -> None: ... +# p/x.pyi +def g() -> None: ... + + + +[case testTestFiles_import] +# modules: p p.x p.tests p.tests.test_foo + +[file p/__init__.py] +def f(): pass + +[file p/x.py] +def g(): pass + +[file p/tests/__init__.py] + +[file p/tests/test_foo.py] +def test_thing(): pass + +[out] +# p/__init__.pyi +def f() -> None: ... +# p/x.pyi +def g() -> None: ... + + + +[case testVerboseFlag] +# Just test that --verbose does not break anything in a basic test case. +# flags: --verbose + +def f(x, y): pass +[out] +from typing import Any + +def f(x: Any, y: Any) -> None: ...