Skip to content

Refactor ProcessPoolExecutor. #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and

## 0.3.1 - 2023-xx-xx

- {pull}`56` refactors the `ProcessPoolExecutor`.

## 0.3.0 - 2023-01-23

- {pull}`50` deprecates INI configurations and aligns the package with pytask v0.3.
Expand Down
33 changes: 31 additions & 2 deletions src/pytask_parallel/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,37 @@
from __future__ import annotations

import enum
from concurrent.futures import Future
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Callable

import cloudpickle


def deserialize_and_run_with_cloudpickle(
    fn: Callable[..., Any], kwargs: dict[str, Any]
) -> Any:
    """Unpickle a function and its keyword arguments, then call the function.

    Both ``fn`` and ``kwargs`` arrive as cloudpickle byte streams (despite the
    annotations, which mirror the executor's ``submit`` signature) and are
    restored with :func:`cloudpickle.loads` before the call.

    """
    func = cloudpickle.loads(fn)
    call_kwargs = cloudpickle.loads(kwargs)
    return func(**call_kwargs)


class CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
    """A :class:`ProcessPoolExecutor` that serializes work with cloudpickle.

    The standard executor relies on :mod:`pickle`, which cannot handle some
    objects (e.g. locally defined functions). Submitted callables and their
    keyword arguments are therefore pickled with cloudpickle in the parent
    process and unpickled again inside the worker.

    """

    # The type signature is wrong for version above Py3.7. Fix when 3.7 is deprecated.
    def submit(  # type: ignore[override]
        self, fn: Callable[..., Any], *args: Any, **kwargs: Any  # noqa: ARG002
    ) -> Future[Any]:
        """Submit a new task.

        Note that positional arguments are accepted for signature
        compatibility but are not forwarded to the callable.

        """
        serialized_fn = cloudpickle.dumps(fn)
        serialized_kwargs = cloudpickle.dumps(kwargs)
        return super().submit(
            deserialize_and_run_with_cloudpickle,
            fn=serialized_fn,
            kwargs=serialized_kwargs,
        )


try:
Expand All @@ -20,7 +49,7 @@ class ParallelBackendChoices(enum.Enum):
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES

PARALLEL_BACKENDS = {
ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
ParallelBackendChoices.THREADS: ThreadPoolExecutor,
}

Expand All @@ -36,7 +65,7 @@ class ParallelBackendChoices(enum.Enum): # type: ignore[no-redef]
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES

PARALLEL_BACKENDS = {
ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
ParallelBackendChoices.THREADS: ThreadPoolExecutor,
ParallelBackendChoices.LOKY: ( # type: ignore[attr-defined]
get_reusable_executor
Expand Down
15 changes: 4 additions & 11 deletions src/pytask_parallel/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import List

import attr
import cloudpickle
from pybaum.tree_util import tree_map
from pytask import console
from pytask import ExecutionReport
Expand Down Expand Up @@ -179,13 +178,10 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
if session.config["n_workers"] > 1:
kwargs = _create_kwargs_for_task(task)

bytes_function = cloudpickle.dumps(task)
bytes_kwargs = cloudpickle.dumps(kwargs)

return session.config["_parallel_executor"].submit(
_unserialize_and_execute_task,
bytes_function=bytes_function,
bytes_kwargs=bytes_kwargs,
task=task,
kwargs=kwargs,
show_locals=session.config["show_locals"],
console_options=console.options,
session_filterwarnings=session.config["filterwarnings"],
Expand All @@ -196,8 +192,8 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:


def _unserialize_and_execute_task( # noqa: PLR0913
bytes_function: bytes,
bytes_kwargs: bytes,
task: Task,
kwargs: dict[str, Any],
show_locals: bool,
console_options: ConsoleOptions,
session_filterwarnings: tuple[str, ...],
Expand All @@ -212,9 +208,6 @@ def _unserialize_and_execute_task( # noqa: PLR0913
"""
__tracebackhide__ = True

task = cloudpickle.loads(bytes_function)
kwargs = cloudpickle.loads(bytes_kwargs)

with warnings.catch_warnings(record=True) as log:
# mypy can't infer that record=True means log is not None; help it.
assert log is not None
Expand Down