Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 111 additions & 10 deletions mypy/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Mypy type checker command line tool."""

import argparse
from gettext import gettext
import os
import subprocess
import sys
import time

from typing import Any, Dict, List, Optional, Tuple, TextIO
from typing import Any, Dict, IO, List, Optional, Sequence, Tuple, TextIO, Union
try:
from typing import NoReturn
except ImportError: # Python 3.5.1
NoReturn = None # type: ignore
from typing_extensions import Final

from mypy import build
Expand Down Expand Up @@ -252,6 +257,99 @@ def infer_python_executable(options: Options,
Define MYPY_CACHE_DIR to override configuration cache_dir path.""" # type: Final


class CapturableArgumentParser(argparse.ArgumentParser):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring that explains what this does and why we need to do this.


"""Override ArgumentParser methods that use sys.stdout/sys.stderr directly.

This is needed because hijacking sys.std* is not thread-safe,
yet output must be captured to properly support mypy.api.run.
"""

def __init__(self, *args: Any, **kwargs: Any):
self.stdout = kwargs.pop('stdout', sys.stdout)
self.stderr = kwargs.pop('stderr', sys.stderr)
super().__init__(*args, **kwargs)

# =====================
# Help-printing methods
# =====================
def print_usage(self, file: Optional[IO[str]] = None) -> None:
if file is None:
file = self.stdout
self._print_message(self.format_usage(), file)

def print_help(self, file: Optional[IO[str]] = None) -> None:
if file is None:
file = self.stdout
self._print_message(self.format_help(), file)

def _print_message(self, message: str, file: Optional[IO[str]] = None) -> None:
if message:
if file is None:
file = self.stderr
file.write(message)

# ===============
# Exiting methods
# ===============
def exit(self, status: int = 0, message: Optional[str] = None) -> NoReturn:
if message:
self._print_message(message, self.stderr)
sys.exit(status)

def error(self, message: str) -> NoReturn:
"""error(message: string)

Prints a usage message incorporating the message to stderr and
exits.

If you override this in a subclass, it should not return -- it
should either exit or raise an exception.
"""
self.print_usage(self.stderr)
args = {'prog': self.prog, 'message': message}
self.exit(2, gettext('%(prog)s: error: %(message)s\n') % args)


class CapturableVersionAction(argparse.Action):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docstring (maybe refer to the docstring that I requested for CapturableArgumentParser).


"""Supplement CapturableArgumentParser to handle --version.

This is nearly identical to argparse._VersionAction except,
like CapturableArgumentParser, it allows output to be captured.

Another notable difference is that version is mandatory.
This allows removing a line in __call__ that falls back to parser.version
(which does not appear to exist).
"""

def __init__(self,
option_strings: Sequence[str],
version: str,
dest: str = argparse.SUPPRESS,
default: str = argparse.SUPPRESS,
help: str = "show program's version number and exit",
stdout: Optional[IO[str]] = None):
super().__init__(
option_strings=option_strings,
dest=dest,
default=default,
nargs=0,
help=help)
self.version = version
self.stdout = stdout or sys.stdout

def __call__(self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[Any], None],
option_string: Optional[str] = None) -> NoReturn:
formatter = parser._get_formatter()
formatter.add_text(self.version)
parser._print_message(formatter.format_help(), self.stdout)
parser.exit()


def process_options(args: List[str],
stdout: Optional[TextIO] = None,
stderr: Optional[TextIO] = None,
Expand All @@ -269,13 +367,15 @@ def process_options(args: List[str],
stdout = stdout or sys.stdout
stderr = stderr or sys.stderr

parser = argparse.ArgumentParser(prog=program,
usage=header,
description=DESCRIPTION,
epilog=FOOTER,
fromfile_prefix_chars='@',
formatter_class=AugmentedHelpFormatter,
add_help=False)
parser = CapturableArgumentParser(prog=program,
usage=header,
description=DESCRIPTION,
epilog=FOOTER,
fromfile_prefix_chars='@',
formatter_class=AugmentedHelpFormatter,
add_help=False,
stdout=stdout,
stderr=stderr)

strict_flag_names = [] # type: List[str]
strict_flag_assignments = [] # type: List[Tuple[str, bool]]
Expand Down Expand Up @@ -328,9 +428,10 @@ def add_invertible_flag(flag: str,
'-v', '--verbose', action='count', dest='verbosity',
help="More verbose messages")
general_group.add_argument(
'-V', '--version', action='version',
'-V', '--version', action=CapturableVersionAction,
version='%(prog)s ' + __version__,
help="Show program's version number and exit")
help="Show program's version number and exit",
stdout=stdout)

config_group = parser.add_argument_group(
title='Config file',
Expand Down
45 changes: 45 additions & 0 deletions mypy/test/testapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from io import StringIO
import sys

import mypy.api

from mypy.test.helpers import Suite


class APISuite(Suite):

def setUp(self) -> None:
self.sys_stdout = sys.stdout
self.sys_stderr = sys.stderr
sys.stdout = self.stdout = StringIO()
sys.stderr = self.stderr = StringIO()

def tearDown(self) -> None:
sys.stdout = self.sys_stdout
sys.stderr = self.sys_stderr
assert self.stdout.getvalue() == ''
assert self.stderr.getvalue() == ''

def test_capture_bad_opt(self) -> None:
"""stderr should be captured when a bad option is passed."""
_, stderr, _ = mypy.api.run(['--some-bad-option'])
assert isinstance(stderr, str)
assert stderr != ''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test that the return value is a string, as otherwise the inequality will be trivially true. This might happen if it's None, at least. Maybe like this: assert isinstance(stderr, str) and stderr != ''


def test_capture_empty(self) -> None:
"""stderr should be captured when a bad option is passed."""
_, stderr, _ = mypy.api.run([])
assert isinstance(stderr, str)
assert stderr != ''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.


def test_capture_help(self) -> None:
"""stdout should be captured when --help is passed."""
stdout, _, _ = mypy.api.run(['--help'])
assert isinstance(stdout, str)
assert stdout != ''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.


def test_capture_version(self) -> None:
"""stdout should be captured when --version is passed."""
stdout, _, _ = mypy.api.run(['--version'])
assert isinstance(stdout, str)
assert stdout != ''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to above.