diff --git a/src/pip/_internal/cli/base_command.py b/src/pip/_internal/cli/base_command.py index c8e7fe4095c..694527c1d65 100644 --- a/src/pip/_internal/cli/base_command.py +++ b/src/pip/_internal/cli/base_command.py @@ -1,8 +1,5 @@ """Base Command class, and related routines""" -# The following comment should be removed at some point in the future. -# mypy: strict-optional=False - from __future__ import absolute_import, print_function import logging @@ -25,7 +22,6 @@ UNKNOWN_ERROR, VIRTUALENV_NOT_FOUND, ) -from pip._internal.download import PipSession from pip._internal.exceptions import ( BadCommand, CommandError, @@ -35,13 +31,12 @@ ) from pip._internal.utils.deprecation import deprecated from pip._internal.utils.logging import BrokenStdoutLoggingError, setup_logging -from pip._internal.utils.misc import get_prog, normalize_path -from pip._internal.utils.outdated import pip_version_check +from pip._internal.utils.misc import get_prog from pip._internal.utils.typing import MYPY_CHECK_RUNNING from pip._internal.utils.virtualenv import running_under_virtualenv if MYPY_CHECK_RUNNING: - from typing import Optional, List, Tuple, Any + from typing import List, Tuple, Any from optparse import Values __all__ = ['Command'] @@ -50,15 +45,14 @@ class Command(object): - name = None # type: Optional[str] - usage = None # type: Optional[str] + usage = None # type: str ignore_require_venv = False # type: bool def __init__(self, name, summary, isolated=False): # type: (str, str, bool) -> None parser_kw = { 'usage': self.usage, - 'prog': '%s %s' % (get_prog(), self.name), + 'prog': '%s %s' % (get_prog(), name), 'formatter': UpdatingDefaultsHelpFormatter(), 'add_help_option': False, 'name': name, @@ -81,62 +75,20 @@ def __init__(self, name, summary, isolated=False): ) self.parser.add_option_group(gen_opts) + def handle_pip_version_check(self, options): + # type: (Values) -> None + """ + This is a no-op so that commands by default do not do the pip version + check. + """ + # Make sure we do the pip version check if the index_group options + # are present. + assert not hasattr(options, 'no_index') + def run(self, options, args): # type: (Values, List[Any]) -> Any raise NotImplementedError - @classmethod - def _get_index_urls(cls, options): - """Return a list of index urls from user-provided options.""" - index_urls = [] - if not getattr(options, "no_index", False): - url = getattr(options, "index_url", None) - if url: - index_urls.append(url) - urls = getattr(options, "extra_index_urls", None) - if urls: - index_urls.extend(urls) - # Return None rather than an empty list - return index_urls or None - - def _build_session(self, options, retries=None, timeout=None): - # type: (Values, Optional[int], Optional[int]) -> PipSession - session = PipSession( - cache=( - normalize_path(os.path.join(options.cache_dir, "http")) - if options.cache_dir else None - ), - retries=retries if retries is not None else options.retries, - insecure_hosts=options.trusted_hosts, - index_urls=self._get_index_urls(options), - ) - - # Handle custom ca-bundles from the user - if options.cert: - session.verify = options.cert - - # Handle SSL client certificate - if options.client_cert: - session.cert = options.client_cert - - # Handle timeouts - if options.timeout or timeout: - session.timeout = ( - timeout if timeout is not None else options.timeout - ) - - # Handle configured proxies - if options.proxy: - session.proxies = { - "http": options.proxy, - "https": options.proxy, - } - - # Determine if we can prompt the user for authentication or not - session.auth.prompting = not options.no_input - - return session - def parse_args(self, args): # type: (List[str]) -> Tuple # factored out for testability @@ -226,21 +178,7 @@ def main(self, args): return UNKNOWN_ERROR finally: - allow_version_check = ( - # Does this command have the index_group options? - hasattr(options, "no_index") and - # Is this command allowed to perform this check? - not (options.disable_pip_version_check or options.no_index) - ) - # Check if we're using the latest version of pip available - if allow_version_check: - session = self._build_session( - options, - retries=0, - timeout=min(5, options.timeout) - ) - with session: - pip_version_check(session, options) + self.handle_pip_version_check(options) # Shutdown the logging module logging.shutdown() diff --git a/src/pip/_internal/cli/req_command.py b/src/pip/_internal/cli/req_command.py index 63776f523b1..cc1d392f9fe 100644 --- a/src/pip/_internal/cli/req_command.py +++ b/src/pip/_internal/cli/req_command.py @@ -1,12 +1,15 @@ -"""Contains the RequirementCommand base class. +"""Contains the Command base classes that depend on PipSession. -This is in a separate module so that Command classes not inheriting from -RequirementCommand don't need to import e.g. the PackageFinder machinery -and all its vendored dependencies. +The classes in this module are in a separate module so the commands not +needing download / PackageFinder capability don't unnecessarily import the +PackageFinder machinery and all its vendored dependencies, etc. """ +import os + from pip._internal.cli.base_command import Command from pip._internal.cli.cmdoptions import make_search_scope +from pip._internal.download import PipSession from pip._internal.exceptions import CommandError from pip._internal.index import PackageFinder from pip._internal.legacy_resolve import Resolver @@ -17,28 +20,119 @@ install_req_from_line, ) from pip._internal.req.req_file import parse_requirements +from pip._internal.utils.misc import normalize_path +from pip._internal.utils.outdated import pip_version_check from pip._internal.utils.typing import MYPY_CHECK_RUNNING if MYPY_CHECK_RUNNING: from optparse import Values from typing import List, Optional, Tuple from pip._internal.cache import WheelCache - from pip._internal.download import PipSession from pip._internal.models.target_python import TargetPython from pip._internal.req.req_set import RequirementSet from pip._internal.req.req_tracker import RequirementTracker from pip._internal.utils.temp_dir import TempDirectory -class RequirementCommand(Command): +class SessionCommandMixin(object): + + """ + A class mixin for command classes needing _build_session(). + """ + + @classmethod + def _get_index_urls(cls, options): + """Return a list of index urls from user-provided options.""" + index_urls = [] + if not getattr(options, "no_index", False): + url = getattr(options, "index_url", None) + if url: + index_urls.append(url) + urls = getattr(options, "extra_index_urls", None) + if urls: + index_urls.extend(urls) + # Return None rather than an empty list + return index_urls or None + + def _build_session(self, options, retries=None, timeout=None): + # type: (Values, Optional[int], Optional[int]) -> PipSession + session = PipSession( + cache=( + normalize_path(os.path.join(options.cache_dir, "http")) + if options.cache_dir else None + ), + retries=retries if retries is not None else options.retries, + insecure_hosts=options.trusted_hosts, + index_urls=self._get_index_urls(options), + ) + + # Handle custom ca-bundles from the user + if options.cert: + session.verify = options.cert + + # Handle SSL client certificate + if options.client_cert: + session.cert = options.client_cert + + # Handle timeouts + if options.timeout or timeout: + session.timeout = ( + timeout if timeout is not None else options.timeout + ) + + # Handle configured proxies + if options.proxy: + session.proxies = { + "http": options.proxy, + "https": options.proxy, + } + + # Determine if we can prompt the user for authentication or not + session.auth.prompting = not options.no_input + + return session + + +class IndexGroupCommand(SessionCommandMixin, Command): + + """ + Abstract base class for commands with the index_group options. + + This also corresponds to the commands that permit the pip version check. + """ + + def handle_pip_version_check(self, options): + # type: (Values) -> None + """ + Do the pip version check if not disabled. + + This overrides the default behavior of not doing the check. + """ + # Make sure the index_group options are present. + assert hasattr(options, 'no_index') + + if options.disable_pip_version_check or options.no_index: + return + + # Otherwise, check if we're using the latest version of pip available. + session = self._build_session( + options, + retries=0, + timeout=min(5, options.timeout) + ) + with session: + pip_version_check(session, options) + + +class RequirementCommand(IndexGroupCommand): @staticmethod def make_requirement_preparer( - temp_directory, # type: TempDirectory - options, # type: Values - req_tracker, # type: RequirementTracker - download_dir=None, # type: str - wheel_download_dir=None, # type: str + temp_directory, # type: TempDirectory + options, # type: Values + req_tracker, # type: RequirementTracker + download_dir=None, # type: str + wheel_download_dir=None, # type: str ): # type: (...) -> RequirementPreparer """ @@ -56,18 +150,18 @@ def make_requirement_preparer( @staticmethod def make_resolver( - preparer, # type: RequirementPreparer - session, # type: PipSession - finder, # type: PackageFinder - options, # type: Values - wheel_cache=None, # type: Optional[WheelCache] - use_user_site=False, # type: bool - ignore_installed=True, # type: bool - ignore_requires_python=False, # type: bool - force_reinstall=False, # type: bool - upgrade_strategy="to-satisfy-only", # type: str - use_pep517=None, # type: Optional[bool] - py_version_info=None # type: Optional[Tuple[int, ...]] + preparer, # type: RequirementPreparer + session, # type: PipSession + finder, # type: PackageFinder + options, # type: Values + wheel_cache=None, # type: Optional[WheelCache] + use_user_site=False, # type: bool + ignore_installed=True, # type: bool + ignore_requires_python=False, # type: bool + force_reinstall=False, # type: bool + upgrade_strategy="to-satisfy-only", # type: str + use_pep517=None, # type: Optional[bool] + py_version_info=None # type: Optional[Tuple[int, ...]] ): # type: (...) -> Resolver """ @@ -90,14 +184,15 @@ def make_resolver( ) @staticmethod - def populate_requirement_set(requirement_set, # type: RequirementSet - args, # type: List[str] - options, # type: Values - finder, # type: PackageFinder - session, # type: PipSession - name, # type: str - wheel_cache # type: Optional[WheelCache] - ): + def populate_requirement_set( + requirement_set, # type: RequirementSet + args, # type: List[str] + options, # type: Values + finder, # type: PackageFinder + session, # type: PipSession + name, # type: str + wheel_cache, # type: Optional[WheelCache] + ): # type: (...) -> None """ Marshal cmd line args into a requirement set. diff --git a/src/pip/_internal/commands/list.py b/src/pip/_internal/commands/list.py index 2fd39097c12..aacd5680ca1 100644 --- a/src/pip/_internal/commands/list.py +++ b/src/pip/_internal/commands/list.py @@ -7,8 +7,8 @@ from pip._vendor.six.moves import zip_longest from pip._internal.cli import cmdoptions -from pip._internal.cli.base_command import Command from pip._internal.cli.cmdoptions import make_search_scope +from pip._internal.cli.req_command import IndexGroupCommand from pip._internal.exceptions import CommandError from pip._internal.index import PackageFinder from pip._internal.models.selection_prefs import SelectionPreferences @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class ListCommand(Command): +class ListCommand(IndexGroupCommand): """ List installed packages, including editables. diff --git a/src/pip/_internal/commands/search.py b/src/pip/_internal/commands/search.py index c96f0b90423..6889375e06d 100644 --- a/src/pip/_internal/commands/search.py +++ b/src/pip/_internal/commands/search.py @@ -12,6 +12,7 @@ from pip._vendor.six.moves import xmlrpc_client # type: ignore from pip._internal.cli.base_command import Command +from pip._internal.cli.req_command import SessionCommandMixin from pip._internal.cli.status_codes import NO_MATCHES_FOUND, SUCCESS from pip._internal.download import PipXmlrpcTransport from pip._internal.exceptions import CommandError @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) -class SearchCommand(Command): +class SearchCommand(SessionCommandMixin, Command): """Search for PyPI packages whose name or summary contains .""" usage = """ diff --git a/src/pip/_internal/commands/uninstall.py b/src/pip/_internal/commands/uninstall.py index ede23083857..6d72400e6b3 100644 --- a/src/pip/_internal/commands/uninstall.py +++ b/src/pip/_internal/commands/uninstall.py @@ -3,13 +3,14 @@ from pip._vendor.packaging.utils import canonicalize_name from pip._internal.cli.base_command import Command +from pip._internal.cli.req_command import SessionCommandMixin from pip._internal.exceptions import InstallationError from pip._internal.req import parse_requirements from pip._internal.req.constructors import install_req_from_line from pip._internal.utils.misc import protect_pip_from_modification_on_windows -class UninstallCommand(Command): +class UninstallCommand(SessionCommandMixin, Command): """ Uninstall packages. diff --git a/tests/unit/test_base_command.py b/tests/unit/test_base_command.py index fc6cf2b7a78..ba34b3922f8 100644 --- a/tests/unit/test_base_command.py +++ b/tests/unit/test_base_command.py @@ -2,6 +2,8 @@ import os import time +from mock import patch + from pip._internal.cli.base_command import Command from pip._internal.utils.logging import BrokenStdoutLoggingError @@ -72,6 +74,16 @@ def test_raise_broken_stdout__debug_logging(self, capsys): assert 'Traceback (most recent call last):' in stderr +@patch('pip._internal.cli.req_command.Command.handle_pip_version_check') +def test_handle_pip_version_check_called(mock_handle_version_check): + """ + Check that Command.handle_pip_version_check() is called. + """ + cmd = FakeCommand() + cmd.main([]) + mock_handle_version_check.assert_called_once() + + class Test_base_command_logging(object): """ Test `pip.base_command.Command` setting up logging consumers based on diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 324e322a3f8..be6c783524d 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -1,7 +1,26 @@ import pytest +from mock import patch +from pip._internal.cli.req_command import ( + IndexGroupCommand, + RequirementCommand, + SessionCommandMixin, +) from pip._internal.commands import commands_dict, create_command +# These are the expected names of the commands whose classes inherit from +# IndexGroupCommand. +EXPECTED_INDEX_GROUP_COMMANDS = ['download', 'install', 'list', 'wheel'] + + +def check_commands(pred, expected): + """ + Check the commands satisfying a predicate. + """ + commands = [create_command(name) for name in sorted(commands_dict)] + actual = [command.name for command in commands if pred(command)] + assert actual == expected, 'actual: {}'.format(actual) + def test_commands_dict__order(): """ @@ -20,3 +39,74 @@ def test_create_command(name): command = create_command(name) assert command.name == name assert command.summary == commands_dict[name].summary + + +def test_session_commands(): + """ + Test which commands inherit from SessionCommandMixin. + """ + def is_session_command(command): + return isinstance(command, SessionCommandMixin) + + expected = ['download', 'install', 'list', 'search', 'uninstall', 'wheel'] + check_commands(is_session_command, expected) + + +def test_index_group_commands(): + """ + Test the commands inheriting from IndexGroupCommand. + """ + def is_index_group_command(command): + return isinstance(command, IndexGroupCommand) + + check_commands(is_index_group_command, EXPECTED_INDEX_GROUP_COMMANDS) + + # Also check that the commands inheriting from IndexGroupCommand are + # exactly the commands with the --no-index option. + def has_option_no_index(command): + return command.parser.has_option('--no-index') + + check_commands(has_option_no_index, EXPECTED_INDEX_GROUP_COMMANDS) + + +@pytest.mark.parametrize('command_name', EXPECTED_INDEX_GROUP_COMMANDS) +@pytest.mark.parametrize( + 'disable_pip_version_check, no_index, expected_called', + [ + # pip_version_check() is only called when both + # disable_pip_version_check and no_index are False. + (False, False, True), + (False, True, False), + (True, False, False), + (True, True, False), + ], +) +@patch('pip._internal.cli.req_command.pip_version_check') +def test_index_group_handle_pip_version_check( + mock_version_check, command_name, disable_pip_version_check, no_index, + expected_called, +): + """ + Test whether pip_version_check() is called when handle_pip_version_check() + is called, for each of the IndexGroupCommand classes. + """ + command = create_command(command_name) + options = command.parser.get_default_values() + options.disable_pip_version_check = disable_pip_version_check + options.no_index = no_index + + command.handle_pip_version_check(options) + if expected_called: + mock_version_check.assert_called_once() + else: + mock_version_check.assert_not_called() + + +def test_requirement_commands(): + """ + Test which commands inherit from RequirementCommand. + """ + def is_requirement_command(command): + return isinstance(command, RequirementCommand) + + check_commands(is_requirement_command, ['download', 'install', 'wheel'])