Skip to content

Simplify code since loky is a dependency. #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits on Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
[Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).

## 0.4.2 - 2024-xx-xx

- {pull}`85` simplifies the backend code by removing the optional-import fallback, since loky is a dependency.

## 0.4.1 - 2024-01-12

- {pull}`72` moves the project to `pyproject.toml`.
Expand Down
44 changes: 12 additions & 32 deletions src/pytask_parallel/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from __future__ import annotations

import enum
from concurrent.futures import Future
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Any
from typing import Callable

import cloudpickle
from loky import get_reusable_executor


def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
Expand Down Expand Up @@ -37,37 +38,16 @@ def submit( # type: ignore[override]
)


try:
from loky import get_reusable_executor
class ParallelBackend(Enum):
"""Choices for parallel backends."""

except ImportError:
PROCESSES = "processes"
THREADS = "threads"
LOKY = "loky"

class ParallelBackend(enum.Enum):
"""Choices for parallel backends."""

PROCESSES = "processes"
THREADS = "threads"

PARALLEL_BACKENDS_DEFAULT = ParallelBackend.PROCESSES

PARALLEL_BACKENDS = {
ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
ParallelBackend.THREADS: ThreadPoolExecutor,
}

else:

class ParallelBackend(enum.Enum): # type: ignore[no-redef]
"""Choices for parallel backends."""

PROCESSES = "processes"
THREADS = "threads"
LOKY = "loky"

PARALLEL_BACKENDS_DEFAULT = ParallelBackend.LOKY # type: ignore[attr-defined]

PARALLEL_BACKENDS = {
ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
ParallelBackend.THREADS: ThreadPoolExecutor,
ParallelBackend.LOKY: get_reusable_executor, # type: ignore[attr-defined]
}
PARALLEL_BACKEND_BUILDER = {
ParallelBackend.PROCESSES: lambda: CloudpickleProcessPoolExecutor,
ParallelBackend.THREADS: lambda: ThreadPoolExecutor,
ParallelBackend.LOKY: lambda: get_reusable_executor,
}
3 changes: 1 addition & 2 deletions src/pytask_parallel/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pytask import EnumChoice
from pytask import hookimpl

from pytask_parallel.backends import PARALLEL_BACKENDS_DEFAULT
from pytask_parallel.backends import ParallelBackend


Expand All @@ -27,7 +26,7 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None:
["--parallel-backend"],
type=EnumChoice(ParallelBackend),
help="Backend for the parallelization.",
default=PARALLEL_BACKENDS_DEFAULT,
default=ParallelBackend.LOKY,
),
]
cli.commands["build"].params.extend(additional_parameters)
12 changes: 10 additions & 2 deletions src/pytask_parallel/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pytask.tree_util import tree_structure
from rich.traceback import Traceback

from pytask_parallel.backends import PARALLEL_BACKENDS
from pytask_parallel.backends import PARALLEL_BACKEND_BUILDER
from pytask_parallel.backends import ParallelBackend

if TYPE_CHECKING:
Expand All @@ -54,6 +54,12 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
else:
config["pm"].register(ProcessesNameSpace)

if PARALLEL_BACKEND_BUILDER[config["parallel_backend"]] is None:
raise
config["_parallel_executor"] = PARALLEL_BACKEND_BUILDER[
config["parallel_backend"]
]()


@hookimpl(tryfirst=True)
def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR0915
Expand All @@ -73,7 +79,9 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
reports = session.execution_reports
running_tasks: dict[str, Future[Any]] = {}

parallel_backend = PARALLEL_BACKENDS[session.config["parallel_backend"]]
parallel_backend = PARALLEL_BACKEND_BUILDER[
session.config["parallel_backend"]
]()

with parallel_backend(max_workers=session.config["n_workers"]) as executor:
session.config["_parallel_executor"] = executor
Expand Down
53 changes: 26 additions & 27 deletions tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytask import ExitCode
from pytask import build
from pytask import cli
from pytask_parallel.backends import PARALLEL_BACKENDS
from pytask_parallel.backends import ParallelBackend
from pytask_parallel.execute import _Sleeper

Expand All @@ -19,18 +18,18 @@ class Session:


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_parallel_execution(tmp_path, parallel_backend):
source = """
import pytask
from pytask import Product
from pathlib import Path
from typing_extensions import Annotated

@pytask.mark.produces("out_1.txt")
def task_1(produces):
produces.write_text("1")
def task_1(path: Annotated[Path, Product] = Path("out_1.txt")):
path.write_text("1")

@pytask.mark.produces("out_2.txt")
def task_2(produces):
produces.write_text("2")
def task_2(path: Annotated[Path, Product] = Path("out_2.txt")):
path.write_text("2")
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
session = build(paths=tmp_path, n_workers=2, parallel_backend=parallel_backend)
Expand All @@ -41,18 +40,18 @@ def task_2(produces):


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_parallel_execution_w_cli(runner, tmp_path, parallel_backend):
source = """
import pytask
from pytask import Product
from pathlib import Path
from typing_extensions import Annotated

@pytask.mark.produces("out_1.txt")
def task_1(produces):
produces.write_text("1")
def task_1(path: Annotated[Path, Product] = Path("out_1.txt")):
path.write_text("1")

@pytask.mark.produces("out_2.txt")
def task_2(produces):
produces.write_text("2")
def task_2(path: Annotated[Path, Product] = Path("out_2.txt")):
path.write_text("2")
"""
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
result = runner.invoke(
Expand All @@ -71,7 +70,7 @@ def task_2(produces):


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_stop_execution_when_max_failures_is_reached(tmp_path, parallel_backend):
source = """
import time
Expand Down Expand Up @@ -99,7 +98,7 @@ def task_3(): time.sleep(3)


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_task_priorities(tmp_path, parallel_backend):
source = """
import pytask
Expand Down Expand Up @@ -140,7 +139,7 @@ def task_5():


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
@pytest.mark.parametrize("show_locals", [True, False])
def test_rendering_of_tracebacks_with_rich(
runner, tmp_path, parallel_backend, show_locals
Expand Down Expand Up @@ -173,12 +172,12 @@ def task_raising_error():
)
def test_collect_warnings_from_parallelized_tasks(runner, tmp_path, parallel_backend):
source = """
import pytask
from pytask import task
import warnings

for i in range(2):

@pytask.mark.task(id=str(i), kwargs={"produces": f"{i}.txt"})
@task(id=str(i), kwargs={"produces": f"{i}.txt"})
def task_example(produces):
warnings.warn("This is a warning.")
produces.touch()
Expand Down Expand Up @@ -222,7 +221,7 @@ def test_sleeper():


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_task_that_return(runner, tmp_path, parallel_backend):
source = """
from pathlib import Path
Expand All @@ -242,7 +241,7 @@ def task_example() -> Annotated[str, Path("file.txt")]:


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
source = """
from pathlib import Path
Expand All @@ -264,7 +263,7 @@ def test_task_without_path_that_return(runner, tmp_path, parallel_backend):

@pytest.mark.end_to_end()
@pytest.mark.parametrize("flag", ["--pdb", "--trace", "--dry-run"])
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_backend):
tmp_path.joinpath("task_example.py").write_text("def task_example(): pass")
result = runner.invoke(
Expand All @@ -278,7 +277,7 @@ def test_parallel_execution_is_deactivated(runner, tmp_path, flag, parallel_back
@pytest.mark.end_to_end()
@pytest.mark.parametrize("code", ["breakpoint()", "import pdb; pdb.set_trace()"])
@pytest.mark.parametrize(
"parallel_backend", [i for i in PARALLEL_BACKENDS if i != ParallelBackend.THREADS]
"parallel_backend", [i for i in ParallelBackend if i != ParallelBackend.THREADS]
)
def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):
tmp_path.joinpath("task_example.py").write_text(f"def task_example(): {code}")
Expand All @@ -290,7 +289,7 @@ def test_raise_error_on_breakpoint(runner, tmp_path, code, parallel_backend):


@pytest.mark.end_to_end()
@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
@pytest.mark.parametrize("parallel_backend", ParallelBackend)
def test_task_partialed(runner, tmp_path, parallel_backend):
source = """
from pathlib import Path
Expand Down