Skip to content

Commit ef2fbca

Browse files
authored
Formalize backends and apply wrappers automatically. (#95)
1 parent 833c17d commit ef2fbca

File tree

13 files changed

+262
-516
lines changed

13 files changed

+262
-516
lines changed

docs/source/changes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
1515
reports.
1616
- {pull}`93` adds documentation on readthedocs.
1717
- {pull}`94` implements `ParallelBackend.NONE` as the default backend.
18+
- {pull}`95` formalizes parallel backends and apply wrappers for backends with threads
19+
or processes automatically.
1820

1921
## 0.4.1 - 2024-01-12
2022

src/pytask_parallel/backends.py

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from attrs import define
1717
from loky import get_reusable_executor
1818

19-
__all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"]
19+
__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"]
2020

2121

2222
def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
@@ -93,37 +93,64 @@ class ParallelBackend(Enum):
9393
THREADS = "threads"
9494

9595

96+
class WorkerType(Enum):
97+
"""A type for workers that either spawned as threads or processes."""
98+
99+
THREADS = "threads"
100+
PROCESSES = "processes"
101+
102+
103+
@define
104+
class _ParallelBackend:
105+
builder: Callable[..., Executor]
106+
worker_type: WorkerType
107+
remote: bool
108+
109+
96110
@define
97111
class ParallelBackendRegistry:
98112
"""Registry for parallel backends."""
99113

100-
registry: ClassVar[dict[ParallelBackend, Callable[..., Executor]]] = {}
114+
registry: ClassVar[dict[ParallelBackend, _ParallelBackend]] = {}
101115

102116
def register_parallel_backend(
103-
self, kind: ParallelBackend, builder: Callable[..., Executor]
117+
self,
118+
kind: ParallelBackend,
119+
builder: Callable[..., Executor],
120+
*,
121+
worker_type: WorkerType | str = WorkerType.PROCESSES,
122+
remote: bool = False,
104123
) -> None:
105124
"""Register a parallel backend."""
106-
self.registry[kind] = builder
125+
self.registry[kind] = _ParallelBackend(
126+
builder=builder, worker_type=WorkerType(worker_type), remote=remote
127+
)
107128

108129
def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor:
109130
"""Get a parallel backend."""
110131
__tracebackhide__ = True
111132
try:
112-
return self.registry[kind](n_workers=n_workers)
133+
return self.registry[kind].builder(n_workers=n_workers)
113134
except KeyError:
114135
msg = f"No registered parallel backend found for kind {kind.value!r}."
115136
raise ValueError(msg) from None
116137
except Exception as e: # noqa: BLE001
117138
msg = f"Could not instantiate parallel backend {kind.value!r}."
118139
raise ValueError(msg) from e
119140

120-
121-
registry = ParallelBackendRegistry()
141+
def reset(self) -> None:
142+
"""Register the default backends."""
143+
self.registry.clear()
144+
for parallel_backend, builder, worker_type, remote in (
145+
(ParallelBackend.DASK, _get_dask_executor, "processes", False),
146+
(ParallelBackend.LOKY, _get_loky_executor, "processes", False),
147+
(ParallelBackend.PROCESSES, _get_process_pool_executor, "processes", False),
148+
(ParallelBackend.THREADS, _get_thread_pool_executor, "threads", False),
149+
):
150+
self.register_parallel_backend(
151+
parallel_backend, builder, worker_type=worker_type, remote=remote
152+
)
122153

123154

124-
registry.register_parallel_backend(ParallelBackend.DASK, _get_dask_executor)
125-
registry.register_parallel_backend(ParallelBackend.LOKY, _get_loky_executor)
126-
registry.register_parallel_backend(
127-
ParallelBackend.PROCESSES, _get_process_pool_executor
128-
)
129-
registry.register_parallel_backend(ParallelBackend.THREADS, _get_thread_pool_executor)
155+
registry = ParallelBackendRegistry()
156+
registry.reset()

src/pytask_parallel/config.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@
77

88
from pytask import hookimpl
99

10-
from pytask_parallel import custom
11-
from pytask_parallel import dask
1210
from pytask_parallel import execute
1311
from pytask_parallel import logging
14-
from pytask_parallel import processes
15-
from pytask_parallel import threads
1612
from pytask_parallel.backends import ParallelBackend
1713

1814

@@ -53,13 +49,3 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
5349
# Register parallel execute and logging hook.
5450
config["pm"].register(logging)
5551
config["pm"].register(execute)
56-
57-
# Register parallel backends.
58-
if config["parallel_backend"] == ParallelBackend.THREADS:
59-
config["pm"].register(threads)
60-
elif config["parallel_backend"] == ParallelBackend.DASK:
61-
config["pm"].register(dask)
62-
elif config["parallel_backend"] == ParallelBackend.CUSTOM:
63-
config["pm"].register(custom)
64-
else:
65-
config["pm"].register(processes)

src/pytask_parallel/custom.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

src/pytask_parallel/dask.py

Lines changed: 0 additions & 208 deletions
This file was deleted.

0 commit comments

Comments
 (0)