@@ -93,23 +93,44 @@ class ParallelBackend(Enum):
93
93
THREADS = "threads"
94
94
95
95
96
+ class WorkerType (Enum ):
97
+ """A type for workers that either spawned as threads or processes."""
98
+
99
+ THREADS = "threads"
100
+ PROCESSES = "processes"
101
+
102
+
103
+ @define
104
+ class _ParallelBackend :
105
+ builder : Callable [..., Executor ]
106
+ worker_type : WorkerType
107
+ remote : bool
108
+
109
+
96
110
@define
97
111
class ParallelBackendRegistry :
98
112
"""Registry for parallel backends."""
99
113
100
- registry : ClassVar [dict [ParallelBackend , Callable [..., Executor ] ]] = {}
114
+ registry : ClassVar [dict [ParallelBackend , _ParallelBackend ]] = {}
101
115
102
116
def register_parallel_backend (
103
- self , kind : ParallelBackend , builder : Callable [..., Executor ]
117
+ self ,
118
+ kind : ParallelBackend ,
119
+ builder : Callable [..., Executor ],
120
+ * ,
121
+ worker_type : WorkerType | str = WorkerType .PROCESSES ,
122
+ remote : bool = False ,
104
123
) -> None :
105
124
"""Register a parallel backend."""
106
- self .registry [kind ] = builder
125
+ self .registry [kind ] = _ParallelBackend (
126
+ builder = builder , worker_type = WorkerType (worker_type ), remote = remote
127
+ )
107
128
108
129
def get_parallel_backend (self , kind : ParallelBackend , n_workers : int ) -> Executor :
109
130
"""Get a parallel backend."""
110
131
__tracebackhide__ = True
111
132
try :
112
- return self .registry [kind ](n_workers = n_workers )
133
+ return self .registry [kind ]. builder (n_workers = n_workers )
113
134
except KeyError :
114
135
msg = f"No registered parallel backend found for kind { kind .value !r} ."
115
136
raise ValueError (msg ) from None
@@ -121,9 +142,27 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo
121
142
registry = ParallelBackendRegistry ()
122
143
123
144
124
- registry .register_parallel_backend (ParallelBackend .DASK , _get_dask_executor )
125
- registry .register_parallel_backend (ParallelBackend .LOKY , _get_loky_executor )
126
145
registry .register_parallel_backend (
127
- ParallelBackend .PROCESSES , _get_process_pool_executor
146
+ ParallelBackend .DASK ,
147
+ _get_dask_executor ,
148
+ worker_type = WorkerType .PROCESSES ,
149
+ remote = False ,
150
+ )
151
+ registry .register_parallel_backend (
152
+ ParallelBackend .LOKY ,
153
+ _get_loky_executor ,
154
+ worker_type = WorkerType .PROCESSES ,
155
+ remote = False ,
156
+ )
157
+ registry .register_parallel_backend (
158
+ ParallelBackend .PROCESSES ,
159
+ _get_process_pool_executor ,
160
+ worker_type = WorkerType .PROCESSES ,
161
+ remote = False ,
162
+ )
163
+ registry .register_parallel_backend (
164
+ ParallelBackend .THREADS ,
165
+ _get_thread_pool_executor ,
166
+ worker_type = WorkerType .THREADS ,
167
+ remote = False ,
128
168
)
129
- registry .register_parallel_backend (ParallelBackend .THREADS , _get_thread_pool_executor )
0 commit comments