Skip to content

Commit 2108f61

Browse files
committed
Fix.
1 parent 49d48ac commit 2108f61

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

src/pytask_parallel/backends.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,36 @@
22
from __future__ import annotations
33

44
import enum
5+
from concurrent.futures import Future
56
from concurrent.futures import ProcessPoolExecutor
67
from concurrent.futures import ThreadPoolExecutor
8+
from typing import Any
9+
from typing import Callable
10+
11+
import cloudpickle
12+
13+
14+
def deserialize_and_run_with_cloudpickle(
15+
fn: Callable[..., Any], /, kwargs: dict[str, Any]
16+
) -> Any:
17+
"""Deserialize and execute a function and keyword arguments."""
18+
deserialized_fn = cloudpickle.loads(fn)
19+
deserialized_kwargs = cloudpickle.loads(kwargs)
20+
return deserialized_fn(**deserialized_kwargs)
21+
22+
23+
class CloudpickleProcessPoolExecutor(ProcessPoolExecutor):
24+
"""Patches the standard executor to serialize functions with cloudpickle."""
25+
26+
def submit(
27+
self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any # noqa: ARG002
28+
) -> Future[Any]:
29+
"""Submit a new task."""
30+
return super().submit(
31+
deserialize_and_run_with_cloudpickle,
32+
cloudpickle.dumps(fn),
33+
kwargs=cloudpickle.dumps(kwargs),
34+
)
735

836

937
try:
@@ -20,7 +48,7 @@ class ParallelBackendChoices(enum.Enum):
2048
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
2149

2250
PARALLEL_BACKENDS = {
23-
ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
51+
ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
2452
ParallelBackendChoices.THREADS: ThreadPoolExecutor,
2553
}
2654

@@ -36,7 +64,7 @@ class ParallelBackendChoices(enum.Enum): # type: ignore[no-redef]
3664
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
3765

3866
PARALLEL_BACKENDS = {
39-
ParallelBackendChoices.PROCESSES: ProcessPoolExecutor,
67+
ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
4068
ParallelBackendChoices.THREADS: ThreadPoolExecutor,
4169
ParallelBackendChoices.LOKY: ( # type: ignore[attr-defined]
4270
get_reusable_executor

src/pytask_parallel/execute.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import List
1313

1414
import attr
15-
import cloudpickle
1615
from pybaum.tree_util import tree_map
1716
from pytask import console
1817
from pytask import ExecutionReport
@@ -179,13 +178,10 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
179178
if session.config["n_workers"] > 1:
180179
kwargs = _create_kwargs_for_task(task)
181180

182-
bytes_function = cloudpickle.dumps(task)
183-
bytes_kwargs = cloudpickle.dumps(kwargs)
184-
185181
return session.config["_parallel_executor"].submit(
186182
_unserialize_and_execute_task,
187-
bytes_function=bytes_function,
188-
bytes_kwargs=bytes_kwargs,
183+
task=task,
184+
kwargs=kwargs,
189185
show_locals=session.config["show_locals"],
190186
console_options=console.options,
191187
session_filterwarnings=session.config["filterwarnings"],
@@ -196,8 +192,8 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
196192

197193

198194
def _unserialize_and_execute_task( # noqa: PLR0913
199-
bytes_function: bytes,
200-
bytes_kwargs: bytes,
195+
task: Task,
196+
kwargs: dict[str, Any],
201197
show_locals: bool,
202198
console_options: ConsoleOptions,
203199
session_filterwarnings: tuple[str, ...],
@@ -212,9 +208,6 @@ def _unserialize_and_execute_task( # noqa: PLR0913
212208
"""
213209
__tracebackhide__ = True
214210

215-
task = cloudpickle.loads(bytes_function)
216-
kwargs = cloudpickle.loads(bytes_kwargs)
217-
218211
with warnings.catch_warnings(record=True) as log:
219212
# mypy can't infer that record=True means log is not None; help it.
220213
assert log is not None

0 commit comments

Comments
 (0)