Skip to content

Formalize backends and apply wrappers automatically. #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
reports.
- {pull}`93` adds documentation on readthedocs.
- {pull}`94` implements `ParallelBackend.NONE` as the default backend.
- {pull}`95` formalizes parallel backends and automatically applies wrappers for
  backends with threads or processes.

## 0.4.1 - 2024-01-12

Expand Down
53 changes: 40 additions & 13 deletions src/pytask_parallel/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from attrs import define
from loky import get_reusable_executor

__all__ = ["ParallelBackend", "ParallelBackendRegistry", "registry"]
__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"]


def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
Expand Down Expand Up @@ -93,37 +93,64 @@ class ParallelBackend(Enum):
THREADS = "threads"


class WorkerType(Enum):
    """The kind of worker a backend spawns: threads or processes."""

    # The string values are what callers may pass in place of the enum member.
    THREADS = "threads"
    PROCESSES = "processes"


@define
class _ParallelBackend:
    """Record bundling an executor builder with metadata about its workers."""

    # Factory that returns an ``Executor``.
    builder: Callable[..., Executor]
    # Whether the backend's workers are threads or processes.
    worker_type: WorkerType
    # NOTE(review): presumably marks backends whose workers run on another
    # machine - confirm against the code that reads this flag.
    remote: bool


@define
class ParallelBackendRegistry:
    """Registry that maps each ``ParallelBackend`` to its builder and metadata.

    The mapping is stored in a class-level dictionary, so every instance of the
    registry sees the same registrations.
    """

    registry: ClassVar[dict[ParallelBackend, _ParallelBackend]] = {}

    def register_parallel_backend(
        self,
        kind: ParallelBackend,
        builder: Callable[..., Executor],
        *,
        worker_type: WorkerType | str = WorkerType.PROCESSES,
        remote: bool = False,
    ) -> None:
        """Register a parallel backend, replacing any previous entry for ``kind``.

        Parameters
        ----------
        kind
            The backend identifier used as the registry key.
        builder
            A callable that is invoked with ``n_workers=...`` and returns an
            executor.
        worker_type
            Either a ``WorkerType`` member or its string value (``"threads"`` or
            ``"processes"``); any other value raises ``ValueError`` from the enum
            lookup.
        remote
            Stored unchanged on the registry entry.

        """
        self.registry[kind] = _ParallelBackend(
            builder=builder, worker_type=WorkerType(worker_type), remote=remote
        )

    def get_parallel_backend(self, kind: ParallelBackend, n_workers: int) -> Executor:
        """Build and return the executor registered for ``kind``.

        Raises
        ------
        ValueError
            If no backend is registered for ``kind``, or if the registered
            builder fails for any reason (the original error is chained).

        """
        # NOTE(review): presumably hides this frame in rendered tracebacks
        # (pytest-style convention) - confirm.
        __tracebackhide__ = True

        # Look the entry up on its own so that a ``KeyError`` raised inside a
        # builder is not misreported as a missing registration.
        try:
            backend = self.registry[kind]
        except KeyError:
            msg = f"No registered parallel backend found for kind {kind.value!r}."
            raise ValueError(msg) from None

        try:
            return backend.builder(n_workers=n_workers)
        except Exception as e:  # noqa: BLE001
            msg = f"Could not instantiate parallel backend {kind.value!r}."
            raise ValueError(msg) from e

    def reset(self) -> None:
        """Clear the registry and register the default backends again."""
        self.registry.clear()
        for parallel_backend, builder, worker_type, remote in (
            (ParallelBackend.DASK, _get_dask_executor, "processes", False),
            (ParallelBackend.LOKY, _get_loky_executor, "processes", False),
            (ParallelBackend.PROCESSES, _get_process_pool_executor, "processes", False),
            (ParallelBackend.THREADS, _get_thread_pool_executor, "threads", False),
        ):
            self.register_parallel_backend(
                parallel_backend, builder, worker_type=worker_type, remote=remote
            )


# The default backends are registered by ``reset()``; separate per-backend
# registration calls are no longer needed here.
registry = ParallelBackendRegistry()
registry.reset()
14 changes: 0 additions & 14 deletions src/pytask_parallel/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,8 @@

from pytask import hookimpl

from pytask_parallel import custom
from pytask_parallel import dask
from pytask_parallel import execute
from pytask_parallel import logging
from pytask_parallel import processes
from pytask_parallel import threads
from pytask_parallel.backends import ParallelBackend


Expand Down Expand Up @@ -53,13 +49,3 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
# Register parallel execute and logging hook.
config["pm"].register(logging)
config["pm"].register(execute)

# Register parallel backends.
if config["parallel_backend"] == ParallelBackend.THREADS:
config["pm"].register(threads)
elif config["parallel_backend"] == ParallelBackend.DASK:
config["pm"].register(dask)
elif config["parallel_backend"] == ParallelBackend.CUSTOM:
config["pm"].register(custom)
else:
config["pm"].register(processes)
27 changes: 0 additions & 27 deletions src/pytask_parallel/custom.py

This file was deleted.

208 changes: 0 additions & 208 deletions src/pytask_parallel/dask.py

This file was deleted.

Loading