Skip to content

Commit b3011f1

Browse files
committed
Formalize backends.
1 parent 833c17d commit b3011f1

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

docs/source/changes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
1515
reports.
1616
- {pull}`93` adds documentation on readthedocs.
1717
- {pull}`94` implements `ParallelBackend.NONE` as the default backend.
18+
- {pull}`95` formalizes parallel backends and apply wrappers for backends with threads
19+
or processes automatically.
1820

1921
## 0.4.1 - 2024-01-12
2022

src/pytask_parallel/backends.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,44 @@ class ParallelBackend(Enum):
9393
THREADS = "threads"
9494

9595

96+
class WorkerType(Enum):
97+
"""A type for workers that either spawned as threads or processes."""
98+
99+
THREADS = "threads"
100+
PROCESSES = "processes"
101+
102+
103+
@define
104+
class _ParallelBackend:
105+
builder: Callable[..., Executor]
106+
worker_type: WorkerType
107+
remote: bool
108+
109+
96110
@define
97111
class ParallelBackendRegistry:
98112
"""Registry for parallel backends."""
99113

100-
registry: ClassVar[dict[ParallelBackend, Callable[..., Executor]]] = {}
114+
registry: ClassVar[dict[ParallelBackend, _ParallelBackend]] = {}
101115

102116
def register_parallel_backend(
103-
self, kind: ParallelBackend, builder: Callable[..., Executor]
117+
self,
118+
kind: ParallelBackend,
119+
builder: Callable[..., Executor],
120+
*,
121+
worker_type: WorkerType | str = WorkerType.PROCESSES,
122+
remote: bool = False,
104123
) -> None:
105124
"""Register a parallel backend."""
106-
self.registry[kind] = builder
125+
self.registry[kind] = _ParallelBackend(
126+
builder=builder, worker_type=WorkerType(worker_type), remote=remote
127+
)
107128

108129
def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor:
109130
"""Get a parallel backend."""
110131
__tracebackhide__ = True
111132
try:
112-
return self.registry[kind](n_workers=n_workers)
133+
return self.registry[kind].builder(n_workers=n_workers)
113134
except KeyError:
114135
msg = f"No registered parallel backend found for kind {kind.value!r}."
115136
raise ValueError(msg) from None
@@ -121,9 +142,27 @@ def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executo
121142
registry = ParallelBackendRegistry()
122143

123144

124-
registry.register_parallel_backend(ParallelBackend.DASK, _get_dask_executor)
125-
registry.register_parallel_backend(ParallelBackend.LOKY, _get_loky_executor)
126145
registry.register_parallel_backend(
127-
ParallelBackend.PROCESSES, _get_process_pool_executor
146+
ParallelBackend.DASK,
147+
_get_dask_executor,
148+
worker_type=WorkerType.PROCESSES,
149+
remote=False,
150+
)
151+
registry.register_parallel_backend(
152+
ParallelBackend.LOKY,
153+
_get_loky_executor,
154+
worker_type=WorkerType.PROCESSES,
155+
remote=False,
156+
)
157+
registry.register_parallel_backend(
158+
ParallelBackend.PROCESSES,
159+
_get_process_pool_executor,
160+
worker_type=WorkerType.PROCESSES,
161+
remote=False,
162+
)
163+
registry.register_parallel_backend(
164+
ParallelBackend.THREADS,
165+
_get_thread_pool_executor,
166+
worker_type=WorkerType.THREADS,
167+
remote=False,
128168
)
129-
registry.register_parallel_backend(ParallelBackend.THREADS, _get_thread_pool_executor)

0 commit comments

Comments
 (0)