|
16 | 16 | from attrs import define
|
17 | 17 | from loky import get_reusable_executor
|
18 | 18 |
|
19 |
| -__all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"] |
| 19 | +__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"] |
20 | 20 |
|
21 | 21 |
|
22 | 22 | def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
|
@@ -93,37 +93,64 @@ class ParallelBackend(Enum):
|
93 | 93 | THREADS = "threads"
|
94 | 94 |
|
95 | 95 |
|
| 96 | +class WorkerType(Enum): |
| 97 | + """A type for workers that either spawned as threads or processes.""" |
| 98 | + |
| 99 | + THREADS = "threads" |
| 100 | + PROCESSES = "processes" |
| 101 | + |
| 102 | + |
| 103 | +@define |
| 104 | +class _ParallelBackend: |
| 105 | + builder: Callable[..., Executor] |
| 106 | + worker_type: WorkerType |
| 107 | + remote: bool |
| 108 | + |
| 109 | + |
96 | 110 | @define
|
97 | 111 | class ParallelBackendRegistry:
|
98 | 112 | """Registry for parallel backends."""
|
99 | 113 |
|
100 |
| - registry: ClassVar[dict[ParallelBackend, Callable[..., Executor]]] = {} |
| 114 | + registry: ClassVar[dict[ParallelBackend, _ParallelBackend]] = {} |
101 | 115 |
|
102 | 116 | def register_parallel_backend(
|
103 |
| - self, kind: ParallelBackend, builder: Callable[..., Executor] |
| 117 | + self, |
| 118 | + kind: ParallelBackend, |
| 119 | + builder: Callable[..., Executor], |
| 120 | + *, |
| 121 | + worker_type: WorkerType | str = WorkerType.PROCESSES, |
| 122 | + remote: bool = False, |
104 | 123 | ) -> None:
|
105 | 124 | """Register a parallel backend."""
|
106 |
| - self.registry[kind] = builder |
| 125 | + self.registry[kind] = _ParallelBackend( |
| 126 | + builder=builder, worker_type=WorkerType(worker_type), remote=remote |
| 127 | + ) |
107 | 128 |
|
108 | 129 | def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor:
|
109 | 130 | """Get a parallel backend."""
|
110 | 131 | __tracebackhide__ = True
|
111 | 132 | try:
|
112 |
| - return self.registry[kind](n_workers=n_workers) |
| 133 | + return self.registry[kind].builder(n_workers=n_workers) |
113 | 134 | except KeyError:
|
114 | 135 | msg = f"No registered parallel backend found for kind {kind.value!r}."
|
115 | 136 | raise ValueError(msg) from None
|
116 | 137 | except Exception as e: # noqa: BLE001
|
117 | 138 | msg = f"Could not instantiate parallel backend {kind.value!r}."
|
118 | 139 | raise ValueError(msg) from e
|
119 | 140 |
|
120 |
| - |
121 |
| -registry = ParallelBackendRegistry() |
| 141 | + def reset(self) -> None: |
| 142 | + """Register the default backends.""" |
| 143 | + self.registry.clear() |
| 144 | + for parallel_backend, builder, worker_type, remote in ( |
| 145 | + (ParallelBackend.DASK, _get_dask_executor, "processes", False), |
| 146 | + (ParallelBackend.LOKY, _get_loky_executor, "processes", False), |
| 147 | + (ParallelBackend.PROCESSES, _get_process_pool_executor, "processes", False), |
| 148 | + (ParallelBackend.THREADS, _get_thread_pool_executor, "threads", False), |
| 149 | + ): |
| 150 | + self.register_parallel_backend( |
| 151 | + parallel_backend, builder, worker_type=worker_type, remote=remote |
| 152 | + ) |
122 | 153 |
|
123 | 154 |
|
124 |
| -registry.register_parallel_backend(ParallelBackend.DASK, _get_dask_executor) |
125 |
| -registry.register_parallel_backend(ParallelBackend.LOKY, _get_loky_executor) |
126 |
| -registry.register_parallel_backend( |
127 |
| - ParallelBackend.PROCESSES, _get_process_pool_executor |
128 |
| -) |
129 |
| -registry.register_parallel_backend(ParallelBackend.THREADS, _get_thread_pool_executor) |
| 155 | +registry = ParallelBackendRegistry() |
| 156 | +registry.reset() |
0 commit comments