diff --git a/docs/source/changes.md b/docs/source/changes.md index 8e44dec1..31eef098 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -25,6 +25,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`536` allows partialed functions to be task functions. - {pull}`540` changes the CLI entry-point and allow `pytask.build(tasks=task_func)` as the signatures suggested. +- {pull}`542` refactors the plugin manager. ## 0.4.4 - 2023-12-04 diff --git a/pyproject.toml b/pyproject.toml index dce6e6ca..d9d00557 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,7 +146,7 @@ ignore = [ convention = "numpy" [tool.pytest.ini_options] -testpaths = ["tests"] +testpaths = ["src", "tests"] markers = [ "wip: Tests that are work-in-progress.", "unit: Flag for unit tests which target mainly a single function.", diff --git a/src/_pytask/build.py b/src/_pytask/build.py index 85e1c883..45444992 100644 --- a/src/_pytask/build.py +++ b/src/_pytask/build.py @@ -15,7 +15,6 @@ from _pytask.capture_utils import CaptureMethod from _pytask.capture_utils import ShowCapture from _pytask.click import ColoredCommand -from _pytask.config import hookimpl from _pytask.config_utils import find_project_root_and_config from _pytask.config_utils import read_config from _pytask.console import console @@ -26,6 +25,8 @@ from _pytask.outcomes import ExitCode from _pytask.path import HashPathCache from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.shared import parse_paths from _pytask.shared import to_list @@ -62,7 +63,7 @@ def pytask_unconfigure(session: Session) -> None: path.write_text(json.dumps(HashPathCache._cache)) -def build( # noqa: C901, PLR0912, PLR0913 +def build( # noqa: C901, PLR0912, PLR0913, PLR0915 *, capture: Literal["fd", "no", "sys", "tee-sys"] | CaptureMethod = CaptureMethod.FD, check_casing_of_paths: bool = True, @@ -177,8 +178,6 @@ def build( # noqa: C901, PLR0912, PLR0913 """ try: - pm = get_plugin_manager() - raw_config = { "capture": capture, "check_casing_of_paths": check_casing_of_paths, @@ -212,6 +211,12 @@ def build( # noqa: C901, PLR0912, PLR0913 **kwargs, } + if "command" not in raw_config: + pm = get_plugin_manager() + storage.store(pm) + else: + pm = storage.get() + # If someone called the programmatic interface, we need to do some parsing. if "command" not in raw_config: raw_config["command"] = "build" diff --git a/src/_pytask/capture.py b/src/_pytask/capture.py index fa46d1ab..7d448e70 100644 --- a/src/_pytask/capture.py +++ b/src/_pytask/capture.py @@ -44,7 +44,7 @@ from _pytask.capture_utils import CaptureMethod from _pytask.capture_utils import ShowCapture from _pytask.click import EnumChoice -from _pytask.config import hookimpl +from _pytask.pluginmanager import hookimpl from _pytask.shared import convert_to_enum if TYPE_CHECKING: diff --git a/src/_pytask/clean.py b/src/_pytask/clean.py index 15ea6add..61e9ee08 100644 --- a/src/_pytask/clean.py +++ b/src/_pytask/clean.py @@ -14,7 +14,6 @@ import click from _pytask.click import ColoredCommand from _pytask.click import EnumChoice -from _pytask.config import hookimpl from _pytask.console import console from _pytask.exceptions import CollectionError from _pytask.git import get_all_files @@ -26,7 +25,8 @@ from _pytask.outcomes import ExitCode from _pytask.path import find_common_ancestor from _pytask.path import relative_to -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.shared import to_list from _pytask.traceback import Traceback @@ -97,12 +97,11 @@ def pytask_parse_config(config: dict[str, Any]) -> None: ) def clean(**raw_config: Any) -> NoReturn: # noqa: C901, PLR0912 """Clean the provided paths by removing files unknown to pytask.""" + pm = storage.get() raw_config["command"] = "clean" try: # Duplication of the same mechanism in :func:`pytask.build`. - pm = get_plugin_manager() - config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config) session = Session.from_config(config) diff --git a/src/_pytask/cli.py b/src/_pytask/cli.py index f056b2e1..f6d6b450 100644 --- a/src/_pytask/cli.py +++ b/src/_pytask/cli.py @@ -2,17 +2,12 @@ from __future__ import annotations from typing import Any -from typing import TYPE_CHECKING import click from _pytask.click import ColoredGroup -from _pytask.config import hookimpl -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import storage from packaging.version import parse as parse_version -if TYPE_CHECKING: - import pluggy - _CONTEXT_SETTINGS: dict[str, Any] = { "help_option_names": ("-h", "--help"), @@ -20,16 +15,16 @@ } -if parse_version(click.__version__) < parse_version("8"): # pragma: no cover - _VERSION_OPTION_KWARGS: dict[str, Any] = {} -else: # pragma: no cover +if parse_version(click.__version__) >= parse_version("8"): # pragma: no cover _VERSION_OPTION_KWARGS = {"package_name": "pytask"} +else: # pragma: no cover + _VERSION_OPTION_KWARGS = {} def _extend_command_line_interface(cli: click.Group) -> click.Group: """Add parameters from plugins to the commandline interface.""" - pm = get_plugin_manager() - pm.hook.pytask_extend_command_line_interface(cli=cli) + pm = storage.create() + pm.hook.pytask_extend_command_line_interface.call_historic(kwargs={"cli": cli}) _sort_options_for_each_command_alphabetically(cli) return cli @@ -42,54 +37,6 @@ def _sort_options_for_each_command_alphabetically(cli: click.Group) -> None: ) -@hookimpl -def pytask_add_hooks(pm: pluggy.PluginManager) -> None: - """Add hooks.""" - from _pytask import build - from _pytask import capture - from _pytask import clean - from _pytask import collect - from _pytask import collect_command - from _pytask import config - from _pytask import database - from _pytask import debugging - from _pytask import execute - from _pytask import dag_command - from _pytask import live - from _pytask import logging - from _pytask import mark - from _pytask import nodes - from _pytask import parameters - from _pytask import persist - from _pytask import profile - from _pytask import dag - from _pytask import skipping - from _pytask import task - from _pytask import warnings - - pm.register(build) - pm.register(capture) - pm.register(clean) - pm.register(collect) - pm.register(collect_command) - pm.register(config) - pm.register(database) - pm.register(debugging) - pm.register(execute) - pm.register(dag_command) - pm.register(live) - pm.register(logging) - pm.register(mark) - pm.register(nodes) - pm.register(parameters) - pm.register(persist) - pm.register(profile) - pm.register(dag) - pm.register(skipping) - pm.register(task) - pm.register(warnings) - - @click.group( cls=ColoredGroup, context_settings=_CONTEXT_SETTINGS, diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index c04335c3..3b262568 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -15,7 +15,6 @@ from _pytask.collect_utils import create_name_of_python_node from _pytask.collect_utils import parse_dependencies_from_task_function from _pytask.collect_utils import parse_products_from_task_function -from _pytask.config import hookimpl from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE from _pytask.console import console from _pytask.console import create_summary_panel @@ -36,6 +35,7 @@ from _pytask.path import find_case_sensitive_path from _pytask.path import import_path from _pytask.path import shorten_path +from _pytask.pluginmanager import hookimpl from _pytask.reports import CollectionReport from _pytask.shared import find_duplicates from _pytask.shared import to_list diff --git a/src/_pytask/collect_command.py b/src/_pytask/collect_command.py index 0ae8648d..f9181643 100644 --- a/src/_pytask/collect_command.py +++ b/src/_pytask/collect_command.py @@ -8,7 +8,6 @@ import click from _pytask.click import ColoredCommand -from _pytask.config import hookimpl from _pytask.console import console from _pytask.console import create_url_style_for_path from _pytask.console import FILE_ICON @@ -27,7 +26,8 @@ from _pytask.outcomes import ExitCode from _pytask.path import find_common_ancestor from _pytask.path import relative_to -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.tree_util import tree_leaves from rich.text import Text @@ -54,10 +54,10 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None: ) def collect(**raw_config: Any | None) -> NoReturn: """Collect tasks and report information about them.""" + pm = storage.get() raw_config["command"] = "collect" try: - pm = get_plugin_manager() config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config) session = Session.from_config(config) diff --git a/src/_pytask/compat.py b/src/_pytask/compat.py index c5213c3d..d392576e 100644 --- a/src/_pytask/compat.py +++ b/src/_pytask/compat.py @@ -1,10 +1,10 @@ """Contains functions to assess compatibility and optional dependencies.""" from __future__ import annotations -import importlib import shutil import sys import warnings +from importlib import import_module from typing import TYPE_CHECKING from packaging.version import parse as parse_version @@ -89,7 +89,9 @@ def import_optional_dependency( f"Use pip or conda to install {install_name!r}." ) try: - module = importlib.import_module(name) + # The from import is used to avoid monkeypatching errors in some tests. See + # https://stackoverflow.com/a/31746577 for more information. + module = import_module(name) except ImportError: if errors == "raise": raise ImportError(msg) from None diff --git a/src/_pytask/config.py b/src/_pytask/config.py index 7cb6d9e6..ef176100 100644 --- a/src/_pytask/config.py +++ b/src/_pytask/config.py @@ -4,14 +4,15 @@ import tempfile from pathlib import Path from typing import Any +from typing import TYPE_CHECKING -import pluggy +from _pytask.pluginmanager import hookimpl from _pytask.shared import parse_markers from _pytask.shared import parse_paths from _pytask.shared import to_list - -hookimpl = pluggy.HookimplMarker("pytask") +if TYPE_CHECKING: + from pluggy import PluginManager _IGNORED_FOLDERS: list[str] = [".git/*", ".venv/*"] @@ -59,9 +60,7 @@ def is_file_system_case_sensitive() -> bool: @hookimpl -def pytask_configure( - pm: pluggy.PluginManager, raw_config: dict[str, Any] -) -> dict[str, Any]: +def pytask_configure(pm: PluginManager, raw_config: dict[str, Any]) -> dict[str, Any]: """Configure pytask.""" # Add all values by default so that many plugins do not need to copy over values. config = {"pm": pm, "markers": {}, **raw_config} diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index 1278b8aa..89941b74 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING import networkx as nx -from _pytask.config import hookimpl from _pytask.console import ARROW_DOWN_ICON from _pytask.console import console from _pytask.console import FILE_ICON @@ -19,6 +18,7 @@ from _pytask.node_protocols import PNode from _pytask.node_protocols import PTask from _pytask.nodes import PythonNode +from _pytask.pluginmanager import hookimpl from _pytask.reports import DagReport from _pytask.shared import reduce_names_of_multiple_nodes from _pytask.tree_util import tree_map diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index 5eeef7da..1dbcdb92 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -12,7 +12,6 @@ from _pytask.click import EnumChoice from _pytask.compat import check_for_optional_program from _pytask.compat import import_optional_dependency -from _pytask.config import hookimpl from _pytask.config_utils import find_project_root_and_config from _pytask.config_utils import read_config from _pytask.console import console @@ -21,6 +20,8 @@ from _pytask.exceptions import ResolvingDependenciesError from _pytask.outcomes import ExitCode from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.shared import parse_paths from _pytask.shared import reduce_names_of_multiple_nodes @@ -80,7 +81,7 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None: def dag(**raw_config: Any) -> int: """Create a visualization of the project's directed acyclic graph.""" try: - pm = get_plugin_manager() + pm = storage.get() config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config) session = Session.from_config(config) @@ -143,6 +144,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: """ try: pm = get_plugin_manager() + storage.store(pm) # If someone called the programmatic interface, we need to do some parsing. if "command" not in raw_config: diff --git a/src/_pytask/data_catalog.py b/src/_pytask/data_catalog.py index a25f1ae7..20b63b56 100644 --- a/src/_pytask/data_catalog.py +++ b/src/_pytask/data_catalog.py @@ -9,6 +9,7 @@ import inspect import pickle from pathlib import Path +from typing import Any from _pytask.config_utils import find_project_root_and_config from _pytask.exceptions import NodeNotCollectedError @@ -16,7 +17,7 @@ from _pytask.node_protocols import PNode from _pytask.node_protocols import PPathNode from _pytask.nodes import PickleNode -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import storage from _pytask.session import Session from attrs import define from attrs import field @@ -34,13 +35,6 @@ def _get_parent_path_of_data_catalog_module(stacklevel: int = 2) -> Path: return Path.cwd() -def _create_default_session() -> Session: - """Create a default session to use the hooks and collect nodes.""" - return Session( - config={"check_casing_of_paths": True}, hook=get_plugin_manager().hook - ) - - @define(kw_only=True) class DataCatalog: """A data catalog. @@ -67,12 +61,14 @@ class DataCatalog: entries: dict[str, PNode] = field(factory=dict) name: str = "default" path: Path | None = None - _session: Session = field(factory=_create_default_session) + _session_config: dict[str, Any] = field( + factory=lambda *x: {"check_casing_of_paths": True} # noqa: ARG005 + ) _instance_path: Path = field(factory=_get_parent_path_of_data_catalog_module) def __attrs_post_init__(self) -> None: root_path, _ = find_project_root_and_config((self._instance_path,)) - self._session.config["paths"] = (root_path,) + self._session_config["paths"] = (root_path,) if not self.path: self.path = root_path / ".pytask" / "data_catalogs" / self.name @@ -115,8 +111,10 @@ def add(self, name: str, node: PNode | None = None) -> None: elif isinstance(node, PNode): self.entries[name] = node else: - collected_node = self._session.hook.pytask_collect_node( - session=self._session, + # Acquire the latest pluginmanager. + session = Session(config=self._session_config, hook=storage.get().hook) + collected_node = session.hook.pytask_collect_node( + session=session, path=self._instance_path, node_info=NodeInfo( arg_name=name, path=(), value=node, task_path=None, task_name="" diff --git a/src/_pytask/database.py b/src/_pytask/database.py index eb40fd2b..2a19a459 100644 --- a/src/_pytask/database.py +++ b/src/_pytask/database.py @@ -4,8 +4,8 @@ from pathlib import Path from typing import Any -from _pytask.config import hookimpl from _pytask.database_utils import create_database +from _pytask.pluginmanager import hookimpl from sqlalchemy.engine import make_url diff --git a/src/_pytask/debugging.py b/src/_pytask/debugging.py index 7da0d985..c95f79d3 100644 --- a/src/_pytask/debugging.py +++ b/src/_pytask/debugging.py @@ -10,15 +10,14 @@ from typing import TYPE_CHECKING import click -from _pytask.config import hookimpl from _pytask.console import console from _pytask.node_protocols import PTask from _pytask.outcomes import Exit +from _pytask.pluginmanager import hookimpl from _pytask.traceback import Traceback - if TYPE_CHECKING: - import pluggy + from pluggy import PluginManager from _pytask.session import Session from types import TracebackType from types import FrameType @@ -111,7 +110,7 @@ def pytask_unconfigure() -> None: class PytaskPDB: """Pseudo PDB that defers to the real pdb.""" - _pluginmanager: pluggy.PluginManager | None = None + _pluginmanager: PluginManager | None = None _config: dict[str, Any] | None = None _saved: ClassVar[list[tuple[Any, ...]]] = [] _recursive_debug: int = 0 diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index ccc87d66..e58f3173 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -7,7 +7,6 @@ from typing import Any from typing import TYPE_CHECKING -from _pytask.config import hookimpl from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE from _pytask.console import console from _pytask.console import create_summary_panel @@ -33,6 +32,7 @@ from _pytask.outcomes import SkippedUnchanged from _pytask.outcomes import TaskOutcome from _pytask.outcomes import WouldBeExecuted +from _pytask.pluginmanager import hookimpl from _pytask.reports import ExecutionReport from _pytask.traceback import remove_traceback_from_exc_info from _pytask.tree_util import tree_leaves diff --git a/src/_pytask/hookspecs.py b/src/_pytask/hookspecs.py index bea498e4..a1d34a69 100644 --- a/src/_pytask/hookspecs.py +++ b/src/_pytask/hookspecs.py @@ -25,13 +25,14 @@ from _pytask.reports import CollectionReport from _pytask.reports import ExecutionReport from _pytask.reports import DagReport + from pluggy import PluginManager hookspec = pluggy.HookspecMarker("pytask") -@hookspec -def pytask_add_hooks(pm: pluggy.PluginManager) -> None: +@hookspec(historic=True) +def pytask_add_hooks(pm: PluginManager) -> None: """Add hook specifications and implementations to the plugin manager. This hook is the first to be called to let plugins register their hook @@ -46,7 +47,7 @@ def pytask_add_hooks(pm: pluggy.PluginManager) -> None: # Hooks for the command-line interface. -@hookspec +@hookspec(historic=True) def pytask_extend_command_line_interface(cli: click.Group) -> None: """Extend the command line interface. @@ -65,9 +66,7 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None: @hookspec(firstresult=True) -def pytask_configure( - pm: pluggy.PluginManager, raw_config: dict[str, Any] -) -> dict[str, Any]: +def pytask_configure(pm: PluginManager, raw_config: dict[str, Any]) -> dict[str, Any]: """Configure pytask. The main hook implementation which controls the configuration and calls subordinated diff --git a/src/_pytask/live.py b/src/_pytask/live.py index 32974426..77982d59 100644 --- a/src/_pytask/live.py +++ b/src/_pytask/live.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING import click -from _pytask.config import hookimpl from _pytask.console import console from _pytask.console import format_task_name from _pytask.outcomes import CollectionOutcome from _pytask.outcomes import TaskOutcome +from _pytask.pluginmanager import hookimpl from attrs import define from attrs import field from rich.box import ROUNDED diff --git a/src/_pytask/logging.py b/src/_pytask/logging.py index 14332917..6dedd1c4 100644 --- a/src/_pytask/logging.py +++ b/src/_pytask/logging.py @@ -12,9 +12,9 @@ import _pytask import click import pluggy -from _pytask.config import hookimpl from _pytask.console import console from _pytask.console import IS_WINDOWS_TERMINAL +from _pytask.pluginmanager import hookimpl from _pytask.reports import ExecutionReport from _pytask.traceback import Traceback from rich.text import Text diff --git a/src/_pytask/mark/__init__.py b/src/_pytask/mark/__init__.py index 20266a13..22055cea 100644 --- a/src/_pytask/mark/__init__.py +++ b/src/_pytask/mark/__init__.py @@ -8,7 +8,6 @@ import click from _pytask.click import ColoredCommand -from _pytask.config import hookimpl from _pytask.console import console from _pytask.dag_utils import task_and_preceding_tasks from _pytask.exceptions import ConfigurationError @@ -19,7 +18,8 @@ from _pytask.mark.structures import MarkDecorator from _pytask.mark.structures import MarkGenerator from _pytask.outcomes import ExitCode -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.shared import parse_markers from attrs import define @@ -49,9 +49,9 @@ def markers(**raw_config: Any) -> NoReturn: """Show all registered markers.""" raw_config["command"] = "markers" + pm = storage.get() try: - pm = get_plugin_manager() config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config) session = Session.from_config(config) diff --git a/src/_pytask/parameters.py b/src/_pytask/parameters.py index ddf7f907..36e547fe 100644 --- a/src/_pytask/parameters.py +++ b/src/_pytask/parameters.py @@ -4,8 +4,8 @@ from pathlib import Path import click -from _pytask.config import hookimpl from _pytask.config_utils import set_defaults_from_config +from _pytask.pluginmanager import hookimpl from click import Context from sqlalchemy.engine import make_url from sqlalchemy.engine import URL diff --git a/src/_pytask/persist.py b/src/_pytask/persist.py index 49b61510..bb74d3bb 100644 --- a/src/_pytask/persist.py +++ b/src/_pytask/persist.py @@ -4,13 +4,13 @@ from typing import Any from typing import TYPE_CHECKING -from _pytask.config import hookimpl from _pytask.dag_utils import node_and_neighbors from _pytask.database_utils import has_node_changed from _pytask.database_utils import update_states_in_database from _pytask.mark_utils import has_mark from _pytask.outcomes import Persisted from _pytask.outcomes import TaskOutcome +from _pytask.pluginmanager import hookimpl if TYPE_CHECKING: diff --git a/src/_pytask/pluginmanager.py b/src/_pytask/pluginmanager.py index aa3acac7..931e12a8 100644 --- a/src/_pytask/pluginmanager.py +++ b/src/_pytask/pluginmanager.py @@ -1,19 +1,109 @@ """Contains the plugin manager.""" from __future__ import annotations -import pluggy +import importlib +import sys +from typing import Iterable + from _pytask import hookspecs +from attrs import define +from pluggy import HookimplMarker +from pluggy import PluginManager + + +__all__ = [ + "get_plugin_manager", + "hookimpl", + "register_hook_impls_from_modules", + "storage", +] + + +hookimpl = HookimplMarker("pytask") + +def register_hook_impls_from_modules( + plugin_manager: PluginManager, module_names: Iterable[str] +) -> None: + """Register hook implementations from modules.""" + for module_name in module_names: + module = importlib.import_module(module_name) + plugin_manager.register(module) -def get_plugin_manager() -> pluggy.PluginManager: + +@hookimpl +def pytask_add_hooks(pm: PluginManager) -> None: + """Add hooks.""" + builtin_hook_impl_modules = ( + "_pytask.build", + "_pytask.capture", + "_pytask.clean", + "_pytask.collect", + "_pytask.collect_command", + "_pytask.config", + "_pytask.dag", + "_pytask.dag_command", + "_pytask.database", + "_pytask.debugging", + "_pytask.execute", + "_pytask.live", + "_pytask.logging", + "_pytask.mark", + "_pytask.nodes", + "_pytask.parameters", + "_pytask.persist", + "_pytask.profile", + "_pytask.skipping", + "_pytask.task", + "_pytask.warnings", + ) + register_hook_impls_from_modules(pm, builtin_hook_impl_modules) + + +def get_plugin_manager() -> PluginManager: """Get the plugin manager.""" - pm = pluggy.PluginManager("pytask") + pm = PluginManager("pytask") pm.add_hookspecs(hookspecs) pm.load_setuptools_entrypoints("pytask") - from _pytask import cli - - pm.register(cli) - pm.hook.pytask_add_hooks(pm=pm) + pm.register(sys.modules[__name__]) + pm.hook.pytask_add_hooks.call_historic(kwargs={"pm": pm}) return pm + + +@define +class _PluginManagerStorage: + """A class to store the plugin manager. + + This storage is needed to harmonize the two different ways to call pytask, via the + CLI or the API. + + When pytask is called from the CLI, the plugin manager is created in + :mod:`_pytask.cli` outside the click command to extend the command line interface. + Afterwards, it needs to be accessed in the different commands. + + When pytask is called from the API, the plugin manager needs to be created inside + the function, for example, :func:`~pytask.build` to ensure each call can start from + a blank slate and is able to register any plugins. + + """ + + _plugin_manager: PluginManager | None = None + + def create(self) -> PluginManager: + """Create the plugin manager.""" + self._plugin_manager = get_plugin_manager() + return self._plugin_manager + + def get(self) -> PluginManager: + """Get the plugin manager.""" + assert self._plugin_manager + return self._plugin_manager + + def store(self, pm: PluginManager) -> None: + """Store the plugin manager.""" + self._plugin_manager = pm + + +storage = _PluginManagerStorage() diff --git a/src/_pytask/profile.py b/src/_pytask/profile.py index 4c873f5d..3147042a 100644 --- a/src/_pytask/profile.py +++ b/src/_pytask/profile.py @@ -14,7 +14,6 @@ import click from _pytask.click import ColoredCommand from _pytask.click import EnumChoice -from _pytask.config import hookimpl from _pytask.console import console from _pytask.console import format_task_name from _pytask.database_utils import BaseTable @@ -25,7 +24,8 @@ from _pytask.node_protocols import PTask from _pytask.outcomes import ExitCode from _pytask.outcomes import TaskOutcome -from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.session import Session from _pytask.traceback import Traceback from rich.table import Table @@ -111,10 +111,10 @@ def _create_or_update_runtime(task_signature: str, start: float, end: float) -> ) def profile(**raw_config: Any) -> NoReturn: """Show information about tasks like runtime and memory consumption of products.""" + pm = storage.get() raw_config["command"] = "profile" try: - pm = get_plugin_manager() config = pm.hook.pytask_configure(pm=pm, raw_config=raw_config) session = Session.from_config(config) diff --git a/src/_pytask/skipping.py b/src/_pytask/skipping.py index baba42d6..7264095e 100644 --- a/src/_pytask/skipping.py +++ b/src/_pytask/skipping.py @@ -4,7 +4,6 @@ from typing import Any from typing import TYPE_CHECKING -from _pytask.config import hookimpl from _pytask.dag_utils import descending_tasks from _pytask.mark import Mark from _pytask.mark_utils import get_marks @@ -13,6 +12,7 @@ from _pytask.outcomes import SkippedAncestorFailed from _pytask.outcomes import SkippedUnchanged from _pytask.outcomes import TaskOutcome +from _pytask.pluginmanager import hookimpl if TYPE_CHECKING: diff --git a/src/_pytask/task.py b/src/_pytask/task.py index 2309695b..895b2d2d 100644 --- a/src/_pytask/task.py +++ b/src/_pytask/task.py @@ -5,8 +5,8 @@ from typing import Callable from typing import TYPE_CHECKING -from _pytask.config import hookimpl from _pytask.console import format_strings_as_flat_tree +from _pytask.pluginmanager import hookimpl from _pytask.shared import find_duplicates from _pytask.task_utils import COLLECTED_TASKS from _pytask.task_utils import parse_collected_tasks_with_task_marker diff --git a/src/_pytask/warnings.py b/src/_pytask/warnings.py index 9ecee42f..e643c70f 100644 --- a/src/_pytask/warnings.py +++ b/src/_pytask/warnings.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING import click -from _pytask.config import hookimpl from _pytask.console import console +from _pytask.pluginmanager import hookimpl from _pytask.warnings_utils import catch_warnings_for_item from _pytask.warnings_utils import parse_filterwarnings from _pytask.warnings_utils import WarningReport diff --git a/src/pytask/__init__.py b/src/pytask/__init__.py index 1b2f25c5..612dc537 100644 --- a/src/pytask/__init__.py +++ b/src/pytask/__init__.py @@ -6,6 +6,7 @@ from _pytask.build import build from _pytask.capture_utils import CaptureMethod from _pytask.capture_utils import ShowCapture +from _pytask.cli import cli from _pytask.click import ColoredCommand from _pytask.click import ColoredGroup from _pytask.click import EnumChoice @@ -15,7 +16,6 @@ from _pytask.collect_utils import produces from _pytask.compat import check_for_optional_program from _pytask.compat import import_optional_dependency -from _pytask.config import hookimpl from _pytask.console import console from _pytask.dag_command import build_dag from _pytask.data_catalog import DataCatalog @@ -59,6 +59,9 @@ from _pytask.outcomes import SkippedAncestorFailed from _pytask.outcomes import SkippedUnchanged from _pytask.outcomes import TaskOutcome +from _pytask.pluginmanager import get_plugin_manager +from _pytask.pluginmanager import hookimpl +from _pytask.pluginmanager import storage from _pytask.profile import Runtime from _pytask.reports import CollectionReport from _pytask.reports import DagReport @@ -74,10 +77,6 @@ from _pytask.warnings_utils import WarningReport -# This import must come last, otherwise a circular import occurs. -from _pytask.cli import cli # noreorder - - __all__ = [ "BaseTable", "CaptureMethod", @@ -136,6 +135,7 @@ "depends_on", "get_all_marks", "get_marks", + "get_plugin_manager", "has_mark", "hash_value", "hookimpl", @@ -149,6 +149,7 @@ "remove_internal_traceback_frames_from_exc_info", "remove_marks", "set_marks", + "storage", "task", "warning_record_to_str", ] diff --git a/tests/conftest.py b/tests/conftest.py index a17802a0..756200e7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from click.testing import CliRunner from packaging import version from pytask import console +from pytask import storage @pytest.fixture(autouse=True) @@ -97,6 +98,7 @@ def _restore_sys_path_and_module_after_test_execution(): class CustomCliRunner(CliRunner): def invoke(self, *args, **kwargs): """Restore sys.path and sys.modules after an invocation.""" + storage.create() with restore_sys_path_and_module_after_test_execution(): return super().invoke(*args, **kwargs) diff --git a/tests/test_clean.py b/tests/test_clean.py index 085ec1b9..91c59f0e 100644 --- a/tests/test_clean.py +++ b/tests/test_clean.py @@ -9,6 +9,7 @@ from _pytask.git import init_repo from pytask import cli from pytask import ExitCode +from pytask import storage _PROJECT_TASK = """ @@ -86,6 +87,7 @@ def test_clean_database_ignored(project, runner): os.chdir(project) result = runner.invoke(cli, ["build"]) assert result.exit_code == ExitCode.OK + storage.create() result = runner.invoke(cli, ["clean"]) assert result.exit_code == ExitCode.OK os.chdir(cwd) diff --git a/tests/test_dag_command.py b/tests/test_dag_command.py index 30ae7561..96143212 100644 --- a/tests/test_dag_command.py +++ b/tests/test_dag_command.py @@ -115,7 +115,7 @@ def task_example(): pass tmp_path.joinpath("input.txt").touch() monkeypatch.setattr( - "_pytask.compat.importlib.import_module", + "_pytask.compat.import_module", lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 ) @@ -150,7 +150,7 @@ def task_create_graph(): tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) monkeypatch.setattr( - "_pytask.compat.importlib.import_module", + "_pytask.compat.import_module", lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 ) @@ -169,7 +169,7 @@ def test_raise_error_with_graph_via_cli_missing_optional_program( monkeypatch, tmp_path, runner ): monkeypatch.setattr( - "_pytask.compat.importlib.import_module", + "_pytask.compat.import_module", lambda x: None, # noqa: ARG005 ) monkeypatch.setattr("_pytask.compat.shutil.which", lambda x: None) # noqa: ARG005 @@ -200,7 +200,7 @@ def test_raise_error_with_graph_via_task_missing_optional_program( monkeypatch, tmp_path, runner ): monkeypatch.setattr( - "_pytask.compat.importlib.import_module", + "_pytask.compat.import_module", lambda x: None, # noqa: ARG005 ) monkeypatch.setattr("_pytask.compat.shutil.which", lambda x: None) # noqa: ARG005