diff --git a/CHANGES.md b/CHANGES.md
index c0f13ac..f07f9b6 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -8,6 +8,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 ## 0.5.0 - 2024-xx-xx
 
 - {pull}`85` simplifies code since loky is a dependency.
+- {pull}`86` adds support for dask.
 - {pull}`88` updates handling `Traceback`.
 - {pull}`89` restructures the package.
 - {pull}`92` redirects stdout and stderr from processes and loky and shows them in error
diff --git a/environment.yml b/environment.yml
index 4def0f2..dba1d27 100644
--- a/environment.yml
+++ b/environment.yml
@@ -16,6 +16,12 @@ dependencies:
   - loky
   - optree
 
+  # Additional dependencies
+  - universal_pathlib <0.2
+  - s3fs>=2023.4.0
+  - coiled
+  - distributed
+
   # Misc
   - tox
   - ipywidgets
diff --git a/pyproject.toml b/pyproject.toml
index 4d65dd2..189d796 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,10 +29,12 @@ name = "Tobias Raabe"
 email = "raabe@posteo.de"
 
 [project.optional-dependencies]
+dask = ["dask[complete]", "distributed"]
 test = [
-    "nbmake",
-    "pytest",
-    "pytest-cov",
+    "pytask-parallel[all]",
+    "nbmake",
+    "pytest",
+    "pytest-cov",
 ]
 
 [project.readme]
@@ -112,6 +114,7 @@ force-single-line = true
 convention = "numpy"
 
 [tool.pytest.ini_options]
+addopts = ["--nbmake"]
 # Do not add src since it messes with the loading of pytask-parallel as a plugin.
 testpaths = ["tests"]
 markers = [
diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py
index 8b53f5c..9db8060 100644
--- a/src/pytask_parallel/backends.py
+++ b/src/pytask_parallel/backends.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import warnings
 from concurrent.futures import Executor
 from concurrent.futures import Future
 from concurrent.futures import ProcessPoolExecutor
@@ -12,6 +13,7 @@
 from typing import ClassVar
 
 import cloudpickle
+from attrs import define
 from loky import get_reusable_executor
 
 __all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"]
@@ -27,7 +29,7 @@ def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
 class _CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
     """Patches the standard executor to serialize functions with cloudpickle."""
 
-    # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated.
+    # The type signature is wrong for Python >3.8. Fix when support is dropped.
     def submit(  # type: ignore[override]
         self,
         fn: Callable[..., Any],
@@ -42,15 +44,54 @@ def submit(  # type: ignore[override]
         )
 
 
+def _get_dask_executor(n_workers: int) -> Executor:
+    """Get an executor from a dask client."""
+    _rich_traceback_omit = True
+    from pytask import import_optional_dependency
+
+    distributed = import_optional_dependency("distributed")
+    try:
+        client = distributed.Client.current()
+    except ValueError:
+        client = distributed.Client(distributed.LocalCluster(n_workers=n_workers))
+    else:
+        if client.cluster and len(client.cluster.workers) != n_workers:
+            warnings.warn(
+                "The number of workers in the dask cluster "
+                f"({len(client.cluster.workers)}) does not match the number of workers "
+                f"requested ({n_workers}). The requested number of workers will be "
+                "ignored.",
+                stacklevel=1,
+            )
+    return client.get_executor()
+
+
+def _get_loky_executor(n_workers: int) -> Executor:
+    """Get a loky executor."""
+    return get_reusable_executor(max_workers=n_workers)
+
+
+def _get_process_pool_executor(n_workers: int) -> Executor:
+    """Get a process pool executor."""
+    return _CloudpickleProcessPoolExecutor(max_workers=n_workers)
+
+
+def _get_thread_pool_executor(n_workers: int) -> Executor:
+    """Get a thread pool executor."""
+    return ThreadPoolExecutor(max_workers=n_workers)
+
+
 class ParallelBackend(Enum):
     """Choices for parallel backends."""
 
     CUSTOM = "custom"
+    DASK = "dask"
     LOKY = "loky"
     PROCESSES = "processes"
     THREADS = "threads"
 
 
+@define
 class ParallelBackendRegistry:
     """Registry for parallel backends."""
 
@@ -68,23 +109,19 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo
         try:
             return self.registry[kind](n_workers=n_workers)
         except KeyError:
-            msg = f"No registered parallel backend found for kind {kind}."
+            msg = f"No registered parallel backend found for kind {kind.value!r}."
             raise ValueError(msg) from None
         except Exception as e:  # noqa: BLE001
-            msg = f"Could not instantiate parallel backend {kind.value}."
+            msg = f"Could not instantiate parallel backend {kind.value!r}."
             raise ValueError(msg) from e
 
 
 registry = ParallelBackendRegistry()
 
+registry.register_parallel_backend(ParallelBackend.DASK, _get_dask_executor)
+registry.register_parallel_backend(ParallelBackend.LOKY, _get_loky_executor)
 registry.register_parallel_backend(
-    ParallelBackend.PROCESSES,
-    lambda n_workers: _CloudpickleProcessPoolExecutor(max_workers=n_workers),
-)
-registry.register_parallel_backend(
-    ParallelBackend.THREADS, lambda n_workers: ThreadPoolExecutor(max_workers=n_workers)
-)
-registry.register_parallel_backend(
-    ParallelBackend.LOKY, lambda n_workers: get_reusable_executor(max_workers=n_workers)
+    ParallelBackend.PROCESSES, _get_process_pool_executor
 )
+registry.register_parallel_backend(ParallelBackend.THREADS, _get_thread_pool_executor)
diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py
index 8e58f55..c60c828 100644
--- a/src/pytask_parallel/config.py
+++ b/src/pytask_parallel/config.py
@@ -8,7 +8,9 @@
 from pytask import hookimpl
 
 from pytask_parallel import custom
+from pytask_parallel import dask
 from pytask_parallel import execute
+from pytask_parallel import logging
 from pytask_parallel import processes
 from pytask_parallel import threads
 from pytask_parallel.backends import ParallelBackend
@@ -37,19 +39,20 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
 @hookimpl(trylast=True)
 def pytask_post_parse(config: dict[str, Any]) -> None:
     """Register the parallel backend if debugging is not enabled."""
+    # Deactivate parallel execution if debugger, trace or dry-run is used.
     if config["pdb"] or config["trace"] or config["dry_run"]:
-        config["n_workers"] = 1
+        return
 
-    # Register parallel execute hook.
-    if config["n_workers"] > 1 or config["parallel_backend"] == ParallelBackend.CUSTOM:
-        config["pm"].register(execute)
+    # Register parallel execute and logging hook.
+    config["pm"].register(logging)
+    config["pm"].register(execute)
 
     # Register parallel backends.
-    if config["n_workers"] > 1:
-        if config["parallel_backend"] == ParallelBackend.THREADS:
-            config["pm"].register(threads)
-        else:
-            config["pm"].register(processes)
-
-    if config["parallel_backend"] == ParallelBackend.CUSTOM:
+    if config["parallel_backend"] == ParallelBackend.THREADS:
+        config["pm"].register(threads)
+    elif config["parallel_backend"] == ParallelBackend.DASK:
+        config["pm"].register(dask)
+    elif config["parallel_backend"] == ParallelBackend.CUSTOM:
         config["pm"].register(custom)
+    else:
+        config["pm"].register(processes)
diff --git a/src/pytask_parallel/dask.py b/src/pytask_parallel/dask.py
new file mode 100644
index 0000000..f71cf7d
--- /dev/null
+++ b/src/pytask_parallel/dask.py
@@ -0,0 +1,208 @@
+"""Contains functions for the dask backend."""
+
+from __future__ import annotations
+
+import inspect
+import sys
+import warnings
+from contextlib import redirect_stderr
+from contextlib import redirect_stdout
+from functools import partial
+from io import StringIO
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import Callable
+
+import cloudpickle
+from pytask import Mark
+from pytask import PTask
+from pytask import PythonNode
+from pytask import Session
+from pytask import Traceback
+from pytask import WarningReport
+from pytask import console
+from pytask import get_marks
+from pytask import hookimpl
+from pytask import parse_warning_filter
+from pytask import warning_record_to_str
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_map
+
+from pytask_parallel.utils import create_kwargs_for_task
+from pytask_parallel.utils import handle_task_function_return
+
+if TYPE_CHECKING:
+    from concurrent.futures import Future
+    from pathlib import Path
+    from types import ModuleType
+    from types import TracebackType
+
+    from rich.console import ConsoleOptions
+
+
+@hookimpl
+def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
+    """Execute a task.
+
+    The task is submitted to the dask executor, which serializes it and sends it to a
+    worker, possibly in another process or on another machine.
+
+    """
+    # Task modules are dynamically loaded and added to `sys.modules`. Thus, cloudpickle
+    # believes the module of the task function is also importable in the child process.
+    # We have to register the module as dynamic again, so that cloudpickle will pickle
+    # it with the function. See cloudpickle#417, pytask#373 and pytask#374.
+    task_module = _get_module(task.function, getattr(task, "path", None))
+    cloudpickle.register_pickle_by_value(task_module)
+
+    kwargs = create_kwargs_for_task(task)
+    return session.config["_parallel_executor"].submit(
+        _execute_task,
+        task=task,
+        kwargs=kwargs,
+        show_locals=session.config["show_locals"],
+        console_options=console.options,
+        session_filterwarnings=session.config["filterwarnings"],
+        task_filterwarnings=get_marks(task, "filterwarnings"),
+    )
+
+
+def _raise_exception_on_breakpoint(*args: Any, **kwargs: Any) -> None:  # noqa: ARG001
+    msg = (
+        "You cannot use 'breakpoint()' or 'pdb.set_trace()' while parallelizing the "
+        "execution of tasks with pytask-parallel. Please, remove the breakpoint or run "
+        "the task without parallelization to debug it."
+    )
+    raise RuntimeError(msg)
+
+
+def _patch_set_trace_and_breakpoint() -> None:
+    """Patch :func:`pdb.set_trace` and :func:`breakpoint`.
+
+    Patch sys.breakpointhook to intercept any call of breakpoint() and pdb.set_trace in
+    a subprocess and print a better exception message.
+
+    """
+    import pdb  # noqa: T100
+    import sys
+
+    pdb.set_trace = _raise_exception_on_breakpoint
+    sys.breakpointhook = _raise_exception_on_breakpoint
+
+
+def _execute_task(  # noqa: PLR0913
+    task: PTask,
+    kwargs: dict[str, Any],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+    session_filterwarnings: tuple[str, ...],
+    task_filterwarnings: tuple[Mark, ...],
+) -> tuple[
+    PyTree[PythonNode | None],
+    list[WarningReport],
+    tuple[type[BaseException], BaseException, str] | None,
+    str,
+    str,
+]:
+    """Unserialize and execute task.
+
+    This function receives bytes and unpickles them to a task which is then executed
+    in a spawned process or thread.
+
+    """
+    # Hide this function from tracebacks.
+    __tracebackhide__ = True
+
+    # Patch set_trace and breakpoint to show a better error message.
+    _patch_set_trace_and_breakpoint()
+
+    captured_stdout_buffer = StringIO()
+    captured_stderr_buffer = StringIO()
+
+    # Catch warnings and store them in a list.
+    with warnings.catch_warnings(record=True) as log, redirect_stdout(
+        captured_stdout_buffer
+    ), redirect_stderr(captured_stderr_buffer):
+        # Apply global filterwarnings.
+        for arg in session_filterwarnings:
+            warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        # Apply filters from "filterwarnings" marks
+        for mark in task_filterwarnings:
+            for arg in mark.args:
+                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        try:
+            out = task.execute(**kwargs)
+        except Exception:  # noqa: BLE001
+            exc_info = sys.exc_info()
+            processed_exc_info = _process_exception(
+                exc_info, show_locals, console_options
+            )
+        else:
+            # Save products.
+            handle_task_function_return(task, out)
+            processed_exc_info = None
+
+    task_display_name = getattr(task, "display_name", task.name)
+    warning_reports = []
+    for warning_message in log:
+        fs_location = warning_message.filename, warning_message.lineno
+        warning_reports.append(
+            WarningReport(
+                message=warning_record_to_str(warning_message),
+                fs_location=fs_location,
+                id_=task_display_name,
+            )
+        )
+
+    captured_stdout_buffer.seek(0)
+    captured_stderr_buffer.seek(0)
+    captured_stdout = captured_stdout_buffer.read()
+    captured_stderr = captured_stderr_buffer.read()
+    captured_stdout_buffer.close()
+    captured_stderr_buffer.close()
+
+    # Collect all PythonNodes that are products to pass values back to the main
+    # process.
+    python_nodes = tree_map(
+        lambda x: x if isinstance(x, PythonNode) else None, task.produces
+    )
+
+    return (
+        python_nodes,
+        warning_reports,
+        processed_exc_info,
+        captured_stdout,
+        captured_stderr,
+    )
+
+
+def _process_exception(
+    exc_info: tuple[type[BaseException], BaseException, TracebackType | None],
+    show_locals: bool,  # noqa: FBT001
+    console_options: ConsoleOptions,
+) -> tuple[type[BaseException], BaseException, str]:
+    """Process the exception and convert the traceback to a string."""
+    traceback = Traceback(exc_info, show_locals=show_locals)
+    segments = console.render(traceback, options=console_options)
+    text = "".join(segment.text for segment in segments)
+    return (*exc_info[:2], text)
+
+
+def _get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
+    """Get the module of a python function.
+
+    ``functools.partial`` obfuscates the module of the function and
+    ``inspect.getmodule`` returns :mod:`functools`. Therefore, we recover the original
+    function.
+
+    We use the path from the task module to aid the search although it is not clear
+    whether it helps.
+
+    """
+    if isinstance(func, partial):
+        func = func.func
+
+    if path:
+        return inspect.getmodule(func, path.as_posix())
+    return inspect.getmodule(func)
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index cb0ea0c..e4ce8e0 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -39,18 +39,19 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
 
     """
     __tracebackhide__ = True
-    reports = session.execution_reports
     running_tasks: dict[str, Future[Any]] = {}
 
-    parallel_backend = registry.get_parallel_backend(
+    # The executor can only be created after the collection to give users the
+    # possibility to inject their own executors.
+    session.config["_parallel_executor"] = registry.get_parallel_backend(
         session.config["parallel_backend"], n_workers=session.config["n_workers"]
     )
 
-    with parallel_backend as executor:
-        session.config["_parallel_executor"] = executor
+    with session.config["_parallel_executor"]:
 
         sleeper = _Sleeper()
 
+        i = 0
        while session.scheduler.is_active():
             try:
                 newly_collected_reports = []
@@ -150,6 +151,8 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
             except KeyboardInterrupt:
                 break
 
+            i += 1
+
     return True
diff --git a/src/pytask_parallel/logging.py b/src/pytask_parallel/logging.py
index 8ee5131..d18c786 100644
--- a/src/pytask_parallel/logging.py
+++ b/src/pytask_parallel/logging.py
@@ -10,6 +10,4 @@
 @hookimpl(trylast=True)
 def pytask_log_session_header(session: Session) -> None:
     """Add a note for how many workers are spawned."""
-    n_workers = session.config["n_workers"]
-    if n_workers > 1:
-        console.print(f"Started {n_workers} workers.")
+    console.print(f"Starting {session.config['n_workers']} workers.")
diff --git a/src/pytask_parallel/plugin.py b/src/pytask_parallel/plugin.py
index 353cdef..a634629 100644
--- a/src/pytask_parallel/plugin.py
+++ b/src/pytask_parallel/plugin.py
@@ -8,7 +8,6 @@
 
 from pytask_parallel import build
 from pytask_parallel import config
-from pytask_parallel import logging
 
 if TYPE_CHECKING:
     from pluggy import PluginManager
@@ -19,4 +18,3 @@ def pytask_add_hooks(pm: PluginManager) -> None:
     """Register plugins."""
     pm.register(build)
     pm.register(config)
-    pm.register(logging)
diff --git a/tests/test_backends.py b/tests/test_backends.py
index b722ffd..92ecc6b 100644
--- a/tests/test_backends.py
+++ b/tests/test_backends.py
@@ -29,7 +29,7 @@ def task_example(): pass
     result = runner.invoke(cli, [tmp_path.as_posix(), "--parallel-backend", "custom"])
     assert result.exit_code == ExitCode.FAILED
     assert "ERROR" in result.output
-    assert "Could not instantiate parallel backend custom." in result.output
+    assert "Could not instantiate parallel backend 'custom'." in result.output
 
 
 @pytest.mark.end_to_end()
diff --git a/tests/test_capture.py b/tests/test_capture.py
index 4659f7b..46dc5af 100644
--- a/tests/test_capture.py
+++ b/tests/test_capture.py
@@ -8,7 +8,17 @@
 
 @pytest.mark.end_to_end()
 @pytest.mark.parametrize(
-    "parallel_backend", [ParallelBackend.PROCESSES, ParallelBackend.LOKY]
+    "parallel_backend",
+    [
+        pytest.param(
+            ParallelBackend.DASK,
+            marks=pytest.mark.skip(
+                reason="dask cannot handle dynamically imported modules."
+ ), + ), + ParallelBackend.LOKY, + ParallelBackend.PROCESSES, + ], ) @pytest.mark.parametrize("show_capture", ["no", "stdout", "stderr", "all"]) def test_show_capture(tmp_path, runner, parallel_backend, show_capture): diff --git a/tests/test_config.py b/tests/test_config.py index f73fded..9e2ce2c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,7 +15,7 @@ [ (False, 1, 1), (True, 1, 1), - (True, 2, 1), + (True, 2, 2), (False, 2, 2), (False, "auto", os.cpu_count() - 1), ], @@ -36,6 +36,13 @@ def test_interplay_between_debugging_and_parallel(tmp_path, pdb, n_workers, expe ] + [ ("parallel_backend", parallel_backend, ExitCode.OK) + if parallel_backend != ParallelBackend.DASK + else pytest.param( + "parallel_backend", + "dask", + ExitCode.CONFIGURATION_FAILED, + marks=pytest.mark.skip(reason="Dask is not yet supported"), + ) for parallel_backend in ParallelBackend ], ) @@ -54,5 +61,5 @@ def test_reading_values_from_config_file( assert session.exit_code == exit_code if value == "auto": value = os.cpu_count() - 1 - if session.exit_code == ExitCode.OK: + if value != "unknown_backend": assert session.config[configuration_option] == value diff --git a/tests/test_execute.py b/tests/test_execute.py index 410572a..275e502 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -12,7 +12,17 @@ from tests.conftest import restore_sys_path_and_module_after_test_execution -_IMPLEMENTED_BACKENDS = [p for p in ParallelBackend if p != ParallelBackend.CUSTOM] +_IMPLEMENTED_BACKENDS = [ + pytest.param( + ParallelBackend.DASK, + marks=pytest.mark.skip( + reason="dask cannot handle dynamically imported modules." + ), + ), + ParallelBackend.LOKY, + ParallelBackend.PROCESSES, + ParallelBackend.THREADS, +] @pytest.mark.end_to_end() @@ -277,11 +287,7 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back @pytest.mark.parametrize("code", ["breakpoint()", "import pdb; pdb.set_trace()"]) @pytest.mark.parametrize( "parallel_backend", - [ - i - for i in ParallelBackend - if i not in (ParallelBackend.THREADS, ParallelBackend.CUSTOM) - ], + [i for i in _IMPLEMENTED_BACKENDS if i != ParallelBackend.THREADS], ) def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend): tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}") diff --git a/tox.ini b/tox.ini index 0be302e..a0c7cd6 100644 --- a/tox.ini +++ b/tox.ini @@ -10,4 +10,4 @@ extras = test deps = git+https://github.com/pytask-dev/pytask.git@main commands = - pytest --nbmake {posargs} + pytest {posargs}