2
2
from __future__ import annotations
3
3
4
4
import enum
5
+ from concurrent .futures import Future
5
6
from concurrent .futures import ProcessPoolExecutor
6
7
from concurrent .futures import ThreadPoolExecutor
8
+ from typing import Any
9
+ from typing import Callable
10
+
11
+ import cloudpickle
12
+
13
+
14
+ def deserialize_and_run_with_cloudpickle (
15
+ fn : Callable [..., Any ], / , kwargs : dict [str , Any ]
16
+ ) -> Any :
17
+ """Deserialize and execute a function and keyword arguments."""
18
+ deserialized_fn = cloudpickle .loads (fn )
19
+ deserialized_kwargs = cloudpickle .loads (kwargs )
20
+ return deserialized_fn (** deserialized_kwargs )
21
+
22
+
23
+ class CloudpickleProcessPoolExecutor (ProcessPoolExecutor ):
24
+ """Patches the standard executor to serialize functions with cloudpickle."""
25
+
26
+ def submit (
27
+ self , fn : Callable [..., Any ], / , * args : Any , ** kwargs : Any # noqa: ARG002
28
+ ) -> Future [Any ]:
29
+ """Submit a new task."""
30
+ return super ().submit (
31
+ deserialize_and_run_with_cloudpickle ,
32
+ cloudpickle .dumps (fn ),
33
+ kwargs = cloudpickle .dumps (kwargs ),
34
+ )
7
35
8
36
9
37
try :
@@ -20,7 +48,7 @@ class ParallelBackendChoices(enum.Enum):
20
48
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices .PROCESSES
21
49
22
50
PARALLEL_BACKENDS = {
23
- ParallelBackendChoices .PROCESSES : ProcessPoolExecutor ,
51
+ ParallelBackendChoices .PROCESSES : CloudpickleProcessPoolExecutor ,
24
52
ParallelBackendChoices .THREADS : ThreadPoolExecutor ,
25
53
}
26
54
@@ -36,7 +64,7 @@ class ParallelBackendChoices(enum.Enum): # type: ignore[no-redef]
36
64
PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices .PROCESSES
37
65
38
66
PARALLEL_BACKENDS = {
39
- ParallelBackendChoices .PROCESSES : ProcessPoolExecutor ,
67
+ ParallelBackendChoices .PROCESSES : CloudpickleProcessPoolExecutor ,
40
68
ParallelBackendChoices .THREADS : ThreadPoolExecutor ,
41
69
ParallelBackendChoices .LOKY : ( # type: ignore[attr-defined]
42
70
get_reusable_executor
0 commit comments