diff --git a/CHANGES.rst b/CHANGES.rst index 5ee5a50..0c04115 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -9,7 +9,8 @@ all releases are available on `Anaconda.org = 1 or 'auto'.") + + return value + + +def parallel_backend_callback(value): + if value not in ["processes", "threads", None]: + raise click.UsageError("parallel_backend has to be 'processes' or 'threads'.") + return value + + +def delay_click_callback(ctx, name, value): # noqa: U100 + return delay_callback(value) + + +def delay_callback(value): + error_occurred = False + if isinstance(value, float) and 0 < value: + pass + elif value is None: + pass + else: + try: + value = float(value) + except ValueError: + error_occurred = True + else: + if value < 0: + error_occurred = True + + if error_occurred: + raise click.UsageError("delay has to be a number greater than 0.") + + return value diff --git a/src/pytask_parallel/cli.py b/src/pytask_parallel/cli.py index d3e10ac..bb1fec8 100644 --- a/src/pytask_parallel/cli.py +++ b/src/pytask_parallel/cli.py @@ -1,38 +1,34 @@ import click from _pytask.config import hookimpl +from pytask_parallel.callbacks import delay_click_callback +from pytask_parallel.callbacks import n_workers_click_callback @hookimpl def pytask_add_parameters_to_cli(command): additional_parameters = [ click.Option( - ["-n", "--n-processes"], + ["-n", "--n-workers"], help=( "Max. number of pytask_parallel tasks. Integer >= 1 or 'auto' which is " "os.cpu_count() - 1. [default: 1 (no parallelization)]" ), - callback=_validate_n_workers, - metavar="", + metavar="[INTEGER|auto]", + callback=n_workers_click_callback, ), click.Option( - ["--pytask_parallel-backend"], + ["--parallel-backend"], type=click.Choice(["processes", "threads"]), help="Backend for the parallelization. [default: processes]", ), click.Option( ["--delay"], - type=float, help=( "Delay between checking whether tasks have finished. [default: 0.1 " "(seconds)]" ), + metavar="NUMBER > 0", + callback=delay_click_callback, ), ] command.params.extend(additional_parameters) - - -def _validate_n_workers(ctx, param, value): # noqa: U100 - if (isinstance(value, int) and value >= 1) or value == "auto": - pass - else: - raise click.UsageError("n-processes can either be an integer >= 1 or 'auto'.") diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py index 865bd4b..e06050f 100644 --- a/src/pytask_parallel/config.py +++ b/src/pytask_parallel/config.py @@ -2,26 +2,37 @@ from _pytask.config import hookimpl from _pytask.shared import get_first_not_none_value +from pytask_parallel.callbacks import n_workers_callback +from pytask_parallel.callbacks import parallel_backend_callback @hookimpl def pytask_parse_config(config, config_from_cli, config_from_file): config["n_workers"] = get_first_not_none_value( - config_from_cli, config_from_file, key="n_workers", default=1 + config_from_cli, + config_from_file, + key="n_workers", + default=1, + callback=n_workers_callback, ) if config["n_workers"] == "auto": config["n_workers"] = max(os.cpu_count() - 1, 1) + config["delay"] = get_first_not_none_value( - config_from_cli, config_from_file, key="delay", default=0.1 + config_from_cli, config_from_file, key="delay", default=0.1, callback=float ) config["parallel_backend"] = get_first_not_none_value( - config_from_cli, config_from_file, key="parallel_backend", default="processes" + config_from_cli, + config_from_file, + key="parallel_backend", + default="processes", + callback=parallel_backend_callback, ) @hookimpl def pytask_post_parse(config): - # Disable parallelization if debugging is enabled. + """Disable parallelization if debugging is enabled.""" if config["pdb"] or config["trace"]: config["n_workers"] = 1 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 0000000..2b2fc7e --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,69 @@ +import functools +from contextlib import ExitStack as does_not_raise # noqa: N813 + +import click +import pytest +from pytask_parallel.callbacks import delay_callback +from pytask_parallel.callbacks import delay_click_callback +from pytask_parallel.callbacks import n_workers_callback +from pytask_parallel.callbacks import n_workers_click_callback +from pytask_parallel.callbacks import parallel_backend_callback + + +partialed_n_workers_callback = functools.partial( + n_workers_click_callback, ctx=None, name=None +) + + +@pytest.mark.unit +@pytest.mark.parametrize( + "value, expectation", + [ + (0, pytest.raises(click.UsageError)), + (1, does_not_raise()), + (2, does_not_raise()), + ("auto", does_not_raise()), + ("asdad", pytest.raises(click.UsageError)), + (None, does_not_raise()), + ], +) +@pytest.mark.parametrize("func", [n_workers_callback, partialed_n_workers_callback]) +def test_n_workers_callback(func, value, expectation): + with expectation: + func(value=value) + + +@pytest.mark.unit +@pytest.mark.parametrize( + "value, expectation", + [ + ("threads", does_not_raise()), + ("processes", does_not_raise()), + (1, pytest.raises(click.UsageError)), + ("asdad", pytest.raises(click.UsageError)), + (None, does_not_raise()), + ], +) +def test_parallel_backend_callback(value, expectation): + with expectation: + parallel_backend_callback(value) + + +partialed_delay_callback = functools.partial(delay_click_callback, ctx=None, name=None) + + +@pytest.mark.unit +@pytest.mark.parametrize( + "value, expectation", + [ + (-1, pytest.raises(click.UsageError)), + (0.1, does_not_raise()), + (1, does_not_raise()), + ("asdad", pytest.raises(click.UsageError)), + (None, does_not_raise()), + ], +) +@pytest.mark.parametrize("func", [delay_callback, partialed_delay_callback]) +def test_delay_callback(func, value, expectation): + with expectation: + func(value=value) diff --git a/tests/test_cli.py b/tests/test_cli.py index 4aec14a..26f90aa 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,20 +1,18 @@ -from contextlib import ExitStack as does_not_raise # noqa: N813 +import subprocess +import textwrap -import click import pytest -from pytask_parallel.cli import _validate_n_workers -@pytest.mark.unit -@pytest.mark.parametrize( - "value, expectation", - [ - (0, pytest.raises(click.UsageError)), - (1, does_not_raise()), - (2, does_not_raise()), - ("auto", does_not_raise()), - ], -) -def test_validate_n_workers(value, expectation): - with expectation: - _validate_n_workers(None, None, value) +@pytest.mark.end_to_end +def test_delay_via_cli(tmp_path): + source = """ + import pytask + + @pytask.mark.produces("out_1.txt") + def task_1(produces): + produces.write_text("1") + """ + tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source)) + + subprocess.run(["pytask", tmp_path.as_posix(), "-n", "2", "--delay", "5"]) diff --git a/tests/test_config.py b/tests/test_config.py index 7b73c64..e1180c8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import os +import textwrap import pytest from pytask import main @@ -18,3 +19,36 @@ def test_interplay_between_debugging_and_parallel(tmp_path, pdb, n_workers, expected): session = main({"paths": tmp_path, "pdb": pdb, "n_workers": n_workers}) assert session.config["n_workers"] == expected + + +@pytest.mark.end_to_end +@pytest.mark.parametrize("config_file", ["pytask.ini", "tox.ini", "setup.cfg"]) +@pytest.mark.parametrize( + "name, value", + [ + ("n_workers", "auto"), + ("n_workers", 1), + ("n_workers", 2), + ("delay", 0.1), + ("delay", 1), + ("parallel_backend", "threads"), + ("parallel_backend", "processes"), + ("parallel_backend", "unknown_backend"), + ], +) +def test_reading_values_from_config_file(tmp_path, config_file, name, value): + config = f""" + [pytask] + {name} = {value} + """ + tmp_path.joinpath(config_file).write_text(textwrap.dedent(config)) + + try: + session = main({"paths": tmp_path}) + except Exception as e: + assert "Error while configuring pytask" in str(e) + else: + assert session.exit_code == 0 + if value == "auto": + value = os.cpu_count() - 1 + assert session.config[name] == value diff --git a/tests/test_execute.py b/tests/test_execute.py index de1e7d2..7eeb836 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -94,3 +94,31 @@ def myfunc(): executor.shutdown() assert future.result() is None + + +@pytest.mark.end_to_end +@pytest.mark.parametrize("parallel_backend", ["processes", "threads"]) +def test_parallel_execution_delay(tmp_path, parallel_backend): + source = """ + import pytask + + @pytask.mark.produces("out_1.txt") + def task_1(produces): + produces.write_text("1") + + @pytask.mark.produces("out_2.txt") + def task_2(produces): + produces.write_text("2") + """ + tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source)) + + session = main( + { + "paths": tmp_path, + "delay": 3, + "n_workers": 2, + "parallel_backend": parallel_backend, + } + ) + + assert 3 < session.execution_end - session.execution_start < 10