diff --git a/CHANGES.md b/CHANGES.md index 86fd0bc..9ecfcd0 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -7,6 +7,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and ## 0.3.1 - 2023-xx-xx +- {pull}`56` refactors the `ProcessPoolExecutor`. + ## 0.3.0 - 2023-01-23 - {pull}`50` deprecates INI configurations and aligns the package with pytask v0.3. diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py index 6d81c3e..3a4d9de 100644 --- a/src/pytask_parallel/backends.py +++ b/src/pytask_parallel/backends.py @@ -2,8 +2,37 @@ from __future__ import annotations import enum +from concurrent.futures import Future from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ThreadPoolExecutor +from typing import Any +from typing import Callable + +import cloudpickle + + +def deserialize_and_run_with_cloudpickle( + fn: Callable[..., Any], kwargs: dict[str, Any] +) -> Any: + """Deserialize and execute a function and keyword arguments.""" + deserialized_fn = cloudpickle.loads(fn) + deserialized_kwargs = cloudpickle.loads(kwargs) + return deserialized_fn(**deserialized_kwargs) + + +class CloudpickleProcessPoolExecutor(ProcessPoolExecutor): + """Patches the standard executor to serialize functions with cloudpickle.""" + + # The type signature is wrong for versions above Py3.7. Fix when 3.7 is deprecated.
+ def submit( # type: ignore[override] + self, fn: Callable[..., Any], *args: Any, **kwargs: Any # noqa: ARG002 + ) -> Future[Any]: + """Submit a new task.""" + return super().submit( + deserialize_and_run_with_cloudpickle, + fn=cloudpickle.dumps(fn), + kwargs=cloudpickle.dumps(kwargs), + ) try: @@ -20,7 +49,7 @@ class ParallelBackendChoices(enum.Enum): PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES PARALLEL_BACKENDS = { - ParallelBackendChoices.PROCESSES: ProcessPoolExecutor, + ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor, ParallelBackendChoices.THREADS: ThreadPoolExecutor, } @@ -36,7 +65,7 @@ class ParallelBackendChoices(enum.Enum): # type: ignore[no-redef] PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES PARALLEL_BACKENDS = { - ParallelBackendChoices.PROCESSES: ProcessPoolExecutor, + ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor, ParallelBackendChoices.THREADS: ThreadPoolExecutor, ParallelBackendChoices.LOKY: ( # type: ignore[attr-defined] get_reusable_executor diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py index 72e7e0f..8f72785 100644 --- a/src/pytask_parallel/execute.py +++ b/src/pytask_parallel/execute.py @@ -12,7 +12,6 @@ from typing import List import attr -import cloudpickle from pybaum.tree_util import tree_map from pytask import console from pytask import ExecutionReport @@ -179,13 +178,10 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None: if session.config["n_workers"] > 1: kwargs = _create_kwargs_for_task(task) - bytes_function = cloudpickle.dumps(task) - bytes_kwargs = cloudpickle.dumps(kwargs) - return session.config["_parallel_executor"].submit( _unserialize_and_execute_task, - bytes_function=bytes_function, - bytes_kwargs=bytes_kwargs, + task=task, + kwargs=kwargs, show_locals=session.config["show_locals"], console_options=console.options, session_filterwarnings=session.config["filterwarnings"], @@ -196,8 +192,8 @@ def 
pytask_execute_task(session: Session, task: Task) -> Future[Any] | None: def _unserialize_and_execute_task( # noqa: PLR0913 - bytes_function: bytes, - bytes_kwargs: bytes, + task: Task, + kwargs: dict[str, Any], show_locals: bool, console_options: ConsoleOptions, session_filterwarnings: tuple[str, ...], @@ -212,9 +208,6 @@ def _unserialize_and_execute_task( # noqa: PLR0913 """ __tracebackhide__ = True - task = cloudpickle.loads(bytes_function) - kwargs = cloudpickle.loads(bytes_kwargs) - with warnings.catch_warnings(record=True) as log: # mypy can't infer that record=True means log is not None; help it. assert log is not None