diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index baab55a..5a7b6a6 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -59,14 +59,14 @@ jobs:
        shell: bash -l {0}
        run: bash <(curl -s https://codecov.io/bash) -F unit -c

-      - name: Run integration tests.
-        shell: bash -l {0}
-        run: tox -e pytest -- tests -m integration --cov=./ --cov-report=xml -n auto
-
-      - name: Upload coverage reports of integration tests.
-        if: runner.os == 'Linux' && matrix.python-version == '3.9'
-        shell: bash -l {0}
-        run: bash <(curl -s https://codecov.io/bash) -F integration -c
+      # - name: Run integration tests.
+      #   shell: bash -l {0}
+      #   run: tox -e pytest -- tests -m integration --cov=./ --cov-report=xml -n auto
+
+      # - name: Upload coverage reports of integration tests.
+      #   if: runner.os == 'Linux' && matrix.python-version == '3.9'
+      #   shell: bash -l {0}
+      #   run: bash <(curl -s https://codecov.io/bash) -F integration -c

       - name: Run end-to-end tests.
         shell: bash -l {0}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2c1b63f..e9e0085 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -72,9 +72,13 @@ repos:
            --ignore-missing-imports,
        ]
        additional_dependencies: [
+           cloudpickle,
+           optree,
+           pytask==0.4.0rc2,
+           rich,
            types-attrs,
            types-click,
-           types-setuptools
+           types-setuptools,
        ]
        pass_filenames: false
-   repo: https://github.com/mgedmin/check-manifest
diff --git a/CHANGES.md b/CHANGES.md
index 98490fe..92e67e5 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -8,6 +8,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 ## 0.4.0 - 2023-xx-xx

 - {pull}`62` deprecates Python 3.7.
+- {pull}`64` aligns pytask-parallel with pytask v0.4.0rc2.

 ## 0.3.1 - 2023-05-27
diff --git a/environment.yml b/environment.yml
index dd1290c..0589e69 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,6 +1,7 @@
 name: pytask-parallel

 channels:
+  - conda-forge/label/pytask_rc
   - conda-forge
   - nodefaults

@@ -10,16 +11,11 @@ dependencies:
   - setuptools_scm
   - toml

-  # Conda
-  - anaconda-client
-  - conda-build
-  - conda-verify
-
   # Package dependencies
-  - pytask >=0.3
+  - pytask>=0.4.0rc2
   - cloudpickle
   - loky
-  - pybaum >=0.1.1
+  - optree

   # Misc
   - black
diff --git a/pyproject.toml b/pyproject.toml
index 71ea451..180b950 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,7 +69,7 @@ convention = "numpy"

 [tool.pytest.ini_options]
 # Do not add src since it messes with the loading of pytask-parallel as a plugin.
-testpaths = ["test"]
+testpaths = ["tests"]
 markers = [
     "wip: Tests that are work-in-progress.",
     "unit: Flag for unit tests which target mainly a single function.",
diff --git a/setup.cfg b/setup.cfg
index dfa3975..a6b481d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,8 +27,8 @@ install_requires =
     click
     cloudpickle
     loky
-    pybaum>=0.1.1
-    pytask>=0.3
+    optree>=0.9.0
+    pytask>=0.4.0rc2
 python_requires = >=3.8
 include_package_data = True
 package_dir = =src
diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py
index 8633a3e..513a04d 100644
--- a/src/pytask_parallel/backends.py
+++ b/src/pytask_parallel/backends.py
@@ -11,9 +11,7 @@
 import cloudpickle


-def deserialize_and_run_with_cloudpickle(
-    fn: Callable[..., Any], kwargs: dict[str, Any]
-) -> Any:
+def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
     """Deserialize and execute a function and keyword arguments."""
     deserialized_fn = cloudpickle.loads(fn)
     deserialized_kwargs = cloudpickle.loads(kwargs)
@@ -40,34 +38,32 @@ def submit(  # type: ignore[override]

 except ImportError:

-    class ParallelBackendChoices(enum.Enum):
+    class ParallelBackend(enum.Enum):
         """Choices for parallel backends."""

         PROCESSES = "processes"
         THREADS = "threads"

+    PARALLEL_BACKENDS_DEFAULT = ParallelBackend.PROCESSES
+
     PARALLEL_BACKENDS = {
-        ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
-        ParallelBackendChoices.THREADS: ThreadPoolExecutor,
+        ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
+        ParallelBackend.THREADS: ThreadPoolExecutor,
     }

 else:

-    class ParallelBackendChoices(enum.Enum):  # type: ignore[no-redef]
+    class ParallelBackend(enum.Enum):  # type: ignore[no-redef]
         """Choices for parallel backends."""

         PROCESSES = "processes"
         THREADS = "threads"
         LOKY = "loky"

-    PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
+    PARALLEL_BACKENDS_DEFAULT = ParallelBackend.LOKY  # type: ignore[attr-defined]

     PARALLEL_BACKENDS = {
-        ParallelBackendChoices.PROCESSES: CloudpickleProcessPoolExecutor,
-        ParallelBackendChoices.THREADS: ThreadPoolExecutor,
-        ParallelBackendChoices.LOKY: (  # type: ignore[attr-defined]
-            get_reusable_executor
-        ),
+        ParallelBackend.PROCESSES: CloudpickleProcessPoolExecutor,
+        ParallelBackend.THREADS: ThreadPoolExecutor,
+        ParallelBackend.LOKY: get_reusable_executor,  # type: ignore[attr-defined]
     }
-
-PARALLEL_BACKENDS_DEFAULT = ParallelBackendChoices.PROCESSES
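The `backends.py` change tightens the `deserialize_and_run_with_cloudpickle` signature to `bytes`, which documents the actual contract: only pickled payloads cross the process boundary. Below is a minimal, self-contained sketch of that round-trip; the `greet` lambda and the executor setup are illustrative and not part of this diff.

```python
# Sketch: how CloudpickleProcessPoolExecutor ships a callable to a worker as
# bytes. cloudpickle serializes the lambda by value, so the worker needs no
# importable module, which is exactly what the stdlib pickle cannot offer.
from concurrent.futures import ProcessPoolExecutor
from typing import Any

import cloudpickle


def deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
    """Deserialize and execute a function and keyword arguments."""
    return cloudpickle.loads(fn)(**cloudpickle.loads(kwargs))


if __name__ == "__main__":
    greet = lambda name: f"Hello, {name}!"  # noqa: E731  # stand-in task function

    with ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            deserialize_and_run_with_cloudpickle,
            cloudpickle.dumps(greet),
            cloudpickle.dumps({"name": "pytask"}),
        )
        print(future.result())  # Hello, pytask!
```

Note also the behavioral change hidden in the reshuffle: `PARALLEL_BACKENDS_DEFAULT` now lives inside both branches instead of being overwritten at module level, so loky becomes the default backend whenever it is importable.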
diff --git a/src/pytask_parallel/build.py b/src/pytask_parallel/build.py
index 9bf403a..b6542f3 100644
--- a/src/pytask_parallel/build.py
+++ b/src/pytask_parallel/build.py
@@ -5,7 +5,7 @@
 from pytask import EnumChoice
 from pytask import hookimpl
 from pytask_parallel.backends import PARALLEL_BACKENDS_DEFAULT
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend


 @hookimpl
@@ -23,7 +23,7 @@ def pytask_extend_command_line_interface(cli: click.Group) -> None:
         ),
         click.Option(
             ["--parallel-backend"],
-            type=EnumChoice(ParallelBackendChoices),
+            type=EnumChoice(ParallelBackend),
             help="Backend for the parallelization.",
             default=PARALLEL_BACKENDS_DEFAULT,
         ),
diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py
index 6c4af23..8258cb5 100644
--- a/src/pytask_parallel/config.py
+++ b/src/pytask_parallel/config.py
@@ -6,7 +6,7 @@
 from typing import Any

 from pytask import hookimpl
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend


 @hookimpl
@@ -17,12 +17,12 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
     if (
         isinstance(config["parallel_backend"], str)
-        and config["parallel_backend"] in ParallelBackendChoices._value2member_map_
+        and config["parallel_backend"] in ParallelBackend._value2member_map_
     ):
-        config["parallel_backend"] = ParallelBackendChoices(config["parallel_backend"])
+        config["parallel_backend"] = ParallelBackend(config["parallel_backend"])
     elif (
         isinstance(config["parallel_backend"], enum.Enum)
-        and config["parallel_backend"] in ParallelBackendChoices
+        and config["parallel_backend"] in ParallelBackend
     ):
         pass
    else:
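The rename from `ParallelBackendChoices` to `ParallelBackend` leaves the user-facing interface untouched: `--parallel-backend` still accepts the enum's string values, and `pytask_parse_config` coerces strings through `_value2member_map_`. A hedged usage sketch follows; the project path is hypothetical.

```python
# Selecting a backend programmatically, mirroring what the updated test suite
# in this diff does with pytask >= 0.4.0rc2.
from pytask import build

from pytask_parallel.backends import ParallelBackend

session = build(
    paths="src/my_project",  # hypothetical project path
    n_workers=2,
    # The plain string "loky" works as well; pytask_parse_config converts it
    # to ParallelBackend.LOKY via ParallelBackend._value2member_map_.
    parallel_backend=ParallelBackend.LOKY,
)
```

The same value can be set persistently as `parallel_backend = "loky"` under `[tool.pytask.ini_options]`, which is what `test_reading_values_from_config_file` exercises below.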
diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py
index f3178a0..1f212ba 100644
--- a/src/pytask_parallel/execute.py
+++ b/src/pytask_parallel/execute.py
@@ -12,20 +12,25 @@
 from typing import List

 import attr
-from pybaum.tree_util import tree_map
+import cloudpickle
 from pytask import console
 from pytask import ExecutionReport
 from pytask import get_marks
 from pytask import hookimpl
 from pytask import Mark
 from pytask import parse_warning_filter
+from pytask import PTask
 from pytask import remove_internal_traceback_frames_from_exc_info
 from pytask import Session
 from pytask import Task
 from pytask import warning_record_to_str
 from pytask import WarningReport
+from pytask.tree_util import PyTree
+from pytask.tree_util import tree_leaves
+from pytask.tree_util import tree_map
+from pytask.tree_util import tree_structure
 from pytask_parallel.backends import PARALLEL_BACKENDS
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend
 from rich.console import ConsoleOptions
 from rich.traceback import Traceback

@@ -33,7 +38,7 @@
 @hookimpl
 def pytask_post_parse(config: dict[str, Any]) -> None:
     """Register the parallel backend."""
-    if config["parallel_backend"] == ParallelBackendChoices.THREADS:
+    if config["parallel_backend"] == ParallelBackend.THREADS:
         config["pm"].register(DefaultBackendNameSpace)
     else:
         config["pm"].register(ProcessesNameSpace)
@@ -99,12 +104,19 @@ def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR091
                 for task_name in list(running_tasks):
                     future = running_tasks[task_name]
                     if future.done():
-                        warning_reports, task_exception = future.result()
-                        session.warnings.extend(warning_reports)
-                        exc_info = (
-                            _parse_future_exception(future.exception())
-                            or task_exception
-                        )
+                        # An exception was raised before the task was executed.
+                        if future.exception() is not None:
+                            exc_info = _parse_future_exception(future.exception())
+                            warning_reports = []
+                        # The task executed; any exception it raised is stored
+                        # in task_exception.
+                        else:
+                            warning_reports, task_exception = future.result()
+                            session.warnings.extend(warning_reports)
+                            exc_info = (
+                                _parse_future_exception(future.exception())
+                                or task_exception
+                            )
+
                         if exc_info is not None:
                             task = session.dag.nodes[task_name]["task"]
                             newly_collected_reports.append(
@@ -165,7 +177,7 @@ class ProcessesNameSpace:

     @staticmethod
     @hookimpl(tryfirst=True)
-    def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
+    def pytask_execute_task(session: Session, task: PTask) -> Future[Any] | None:
         """Execute a task.

         Take a task, pickle it and send the bytes over to another process.
@@ -174,27 +186,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
         if session.config["n_workers"] > 1:
             kwargs = _create_kwargs_for_task(task)

+            # Task modules are dynamically loaded and added to `sys.modules`. Thus,
+            # cloudpickle believes the module of the task function is also importable
+            # in the child process. We have to register the module as dynamic again so
+            # that cloudpickle will pickle it with the function. See cloudpickle#417,
+            # pytask#373 and pytask#374.
+            task_module = inspect.getmodule(task.function)
+            cloudpickle.register_pickle_by_value(task_module)
+
             return session.config["_parallel_executor"].submit(
-                _unserialize_and_execute_task,
+                _execute_task,
                 task=task,
                 kwargs=kwargs,
                 show_locals=session.config["show_locals"],
                 console_options=console.options,
                 session_filterwarnings=session.config["filterwarnings"],
                 task_filterwarnings=get_marks(task, "filterwarnings"),
-                task_short_name=task.short_name,
             )
         return None


-def _unserialize_and_execute_task(  # noqa: PLR0913
-    task: Task,
+def _execute_task(  # noqa: PLR0913
+    task: PTask,
     kwargs: dict[str, Any],
     show_locals: bool,
     console_options: ConsoleOptions,
     session_filterwarnings: tuple[str, ...],
     task_filterwarnings: tuple[Mark, ...],
-    task_short_name: str,
 ) -> tuple[list[WarningReport], tuple[type[BaseException], BaseException, str] | None]:
     """Unserialize and execute task.
@@ -217,15 +235,33 @@
             warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

         try:
-            task.execute(**kwargs)
+            out = task.execute(**kwargs)
         except Exception:  # noqa: BLE001
             exc_info = sys.exc_info()
             processed_exc_info = _process_exception(
                 exc_info, show_locals, console_options
             )
         else:
+            if "return" in task.produces:
+                structure_out = tree_structure(out)
+                structure_return = tree_structure(task.produces["return"])
+                # strict must be False because None is treated as a leaf.
+                if not structure_return.is_prefix(structure_out, strict=False):
+                    msg = (
+                        "The structure of the return annotation is not a subtree of "
+                        "the structure of the function return.\n\nFunction return: "
+                        f"{structure_out}\n\nReturn annotation: {structure_return}"
+                    )
+                    raise ValueError(msg)
+
+                nodes = tree_leaves(task.produces["return"])
+                values = structure_return.flatten_up_to(out)
+                for node, value in zip(nodes, values):
+                    node.save(value)  # type: ignore[attr-defined]
+
             processed_exc_info = None

+        task_display_name = getattr(task, "display_name", task.name)
         warning_reports = []
         for warning_message in log:
             fs_location = warning_message.filename, warning_message.lineno
@@ -233,7 +269,7 @@
                 WarningReport(
                     message=warning_record_to_str(warning_message),
                     fs_location=fs_location,
-                    id_=task_short_name,
+                    id_=task_display_name,
                 )
             )
@@ -293,15 +329,17 @@ def _mock_processes_for_threads(
     return [], exc_info


-def _create_kwargs_for_task(task: Task) -> dict[Any, Any]:
+def _create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
     """Create kwargs for task function."""
-    kwargs = {**task.kwargs}
+    parameters = inspect.signature(task.function).parameters
+
+    kwargs = {}
+    for name, value in task.depends_on.items():
+        kwargs[name] = tree_map(lambda x: x.load(), value)

-    func_arg_names = set(inspect.signature(task.function).parameters)
-    for arg_name in ("depends_on", "produces"):
-        if arg_name in func_arg_names:
-            attribute = getattr(task, arg_name)
-            kwargs[arg_name] = tree_map(lambda x: x.value, attribute)
+    for name, value in task.produces.items():
+        if name in parameters:
+            kwargs[name] = tree_map(lambda x: x.load(), value)

     return kwargs
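The new return handling in `_execute_task` is the densest part of this diff. The sketch below replays its structure check with stand-in nodes: the `FakeNode` class is invented for illustration, and I am assuming `pytask.tree_util` behaves like the `optree` functions it wraps, which the imports in this hunk suggest.

```python
# Sketch of the return handling in _execute_task: verify that the return
# annotation is a structural prefix of what the task returned, then save each
# value on the matching node.
from dataclasses import dataclass, field

from pytask.tree_util import tree_leaves, tree_structure


@dataclass
class FakeNode:
    """Stand-in for a pytask node; only the .save() interface matters here."""

    saved: list = field(default_factory=list)

    def save(self, value):
        self.saved.append(value)


out = {"first": "a", "second": ("b", "c")}  # what task.execute() returned
annotation = {"first": FakeNode(), "second": (FakeNode(), FakeNode())}

structure_out = tree_structure(out)
structure_return = tree_structure(annotation)

# strict=False matters because pytask treats None as a leaf that may stand in
# for an arbitrary subtree of the return value.
assert structure_return.is_prefix(structure_out, strict=False)

# Pair each annotation node with the matching piece of the return and persist
# it; this is the loop _execute_task runs after a successful task.execute().
for node, value in zip(tree_leaves(annotation), structure_return.flatten_up_to(out)):
    node.save(value)

assert [n.saved for n in tree_leaves(annotation)] == [["a"], ["b"], ["c"]]
```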
diff --git a/tests/conftest.py b/tests/conftest.py
index 541f8d3..8cba383 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,71 @@
 from __future__ import annotations

+import sys
+from contextlib import contextmanager
+from typing import Callable
+
 import pytest
 from click.testing import CliRunner


+class SysPathsSnapshot:
+    """A snapshot for sys.path."""
+
+    def __init__(self) -> None:
+        self.__saved = list(sys.path), list(sys.meta_path)
+
+    def restore(self) -> None:
+        sys.path[:], sys.meta_path[:] = self.__saved
+
+
+class SysModulesSnapshot:
+    """A snapshot for sys.modules."""
+
+    def __init__(self, preserve: Callable[[str], bool] | None = None) -> None:
+        self.__preserve = preserve
+        self.__saved = dict(sys.modules)
+
+    def restore(self) -> None:
+        if self.__preserve:
+            self.__saved.update(
+                (k, m) for k, m in sys.modules.items() if self.__preserve(k)
+            )
+        sys.modules.clear()
+        sys.modules.update(self.__saved)
+
+
+@contextmanager
+def restore_sys_path_and_module_after_test_execution():
+    sys_path_snapshot = SysPathsSnapshot()
+    sys_modules_snapshot = SysModulesSnapshot()
+    yield
+    sys_modules_snapshot.restore()
+    sys_path_snapshot.restore()
+
+
+@pytest.fixture(autouse=True)
+def _restore_sys_path_and_module_after_test_execution():
+    """Restore sys.path and sys.modules after every test execution.
+
+    This fixture became necessary because most task modules in the tests are named
+    `task_example`. Since the change in #424, the same module is not reimported, which
+    solves errors with parallelization. At the same time, modules with the same name
+    in the tests overshadow one another and make tests fail.
+
+    The changes to `sys.path` might not be necessary to restore, but we do it anyway.
+
+    """
+    with restore_sys_path_and_module_after_test_execution():
+        yield
+
+
+class CustomCliRunner(CliRunner):
+    def invoke(self, *args, **kwargs):
+        """Restore sys.path and sys.modules after an invocation."""
+        with restore_sys_path_and_module_after_test_execution():
+            return super().invoke(*args, **kwargs)
+
+
 @pytest.fixture()
 def runner():
-    return CliRunner()
+    return CustomCliRunner()
diff --git a/tests/test_config.py b/tests/test_config.py
index 9f37753..f173a73 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -4,9 +4,9 @@
 import textwrap

 import pytest
+from pytask import build
 from pytask import ExitCode
-from pytask import main
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend


 @pytest.mark.end_to_end()
@@ -21,7 +21,7 @@
     ],
 )
 def test_interplay_between_debugging_and_parallel(tmp_path, pdb, n_workers, expected):
-    session = main({"paths": tmp_path, "pdb": pdb, "n_workers": n_workers})
+    session = build(paths=tmp_path, pdb=pdb, n_workers=n_workers)

     assert session.config["n_workers"] == expected

@@ -36,20 +36,20 @@ def test_interplay_between_debugging_and_parallel(tmp_path, pdb, n_workers, expe
     ]
     + [
         ("parallel_backend", parallel_backend, ExitCode.OK)
-        for parallel_backend in ParallelBackendChoices
+        for parallel_backend in ParallelBackend
     ],
 )
 def test_reading_values_from_config_file(
     tmp_path, configuration_option, value, exit_code
 ):
-    config_value = value.value if isinstance(value, ParallelBackendChoices) else value
+    config_value = value.value if isinstance(value, ParallelBackend) else value
     config = f"""
     [tool.pytask.ini_options]
     {configuration_option} = {config_value!r}
     """
     tmp_path.joinpath("pyproject.toml").write_text(textwrap.dedent(config))

-    session = main({"paths": tmp_path})
+    session = build(paths=tmp_path)
     assert session.exit_code == exit_code

     if value == "auto":
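The new `conftest.py` machinery exists because every test now writes its tasks to a module called `task_example.py`; once one of them lands in `sys.modules`, later imports of the same name silently reuse it. A self-contained demonstration of the shadowing the autouse fixture undoes (paths and module contents are hypothetical):

```python
# Demonstration: a cached sys.modules entry shadows a same-named module that
# lives somewhere else on sys.path.
import importlib
import sys
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
for directory, body in [("first", "x = 1"), ("second", "x = 2")]:
    (tmp / directory).mkdir()
    (tmp / directory / "task_example.py").write_text(body)

sys.path.insert(0, str(tmp / "first"))
first = importlib.import_module("task_example")
assert first.x == 1

sys.path.insert(0, str(tmp / "second"))
second = importlib.import_module("task_example")
# The cached module wins: without restoring sys.modules between tests, the
# second file would never be executed.
assert second is first
assert second.x == 1
```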
diff --git a/tests/test_execute.py b/tests/test_execute.py
index 96641b4..88f88c2 100644
--- a/tests/test_execute.py
+++ b/tests/test_execute.py
@@ -1,20 +1,17 @@
 from __future__ import annotations

-import pickle
 import textwrap
-from pathlib import Path
 from time import time

 import pytest
+from pytask import build
 from pytask import cli
 from pytask import ExitCode
-from pytask import main
-from pytask import Task
 from pytask_parallel.backends import PARALLEL_BACKENDS
-from pytask_parallel.backends import ParallelBackendChoices
+from pytask_parallel.backends import ParallelBackend
 from pytask_parallel.execute import _Sleeper
-from pytask_parallel.execute import DefaultBackendNameSpace
-from pytask_parallel.execute import ProcessesNameSpace
+
+from tests.conftest import restore_sys_path_and_module_after_test_execution


 class Session:
@@ -23,69 +20,41 @@
 @pytest.mark.end_to_end()
 @pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
-def test_parallel_execution_speedup(tmp_path, parallel_backend):
+def test_parallel_execution(tmp_path, parallel_backend):
     source = """
     import pytask
-    import time

     @pytask.mark.produces("out_1.txt")
     def task_1(produces):
-        time.sleep(5)
         produces.write_text("1")

     @pytask.mark.produces("out_2.txt")
     def task_2(produces):
-        time.sleep(5)
         produces.write_text("2")
     """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
-
-    session = main({"paths": tmp_path})
-
-    assert session.exit_code == ExitCode.OK
-    assert session.execution_end - session.execution_start > 10
-
-    tmp_path.joinpath("out_1.txt").unlink()
-    tmp_path.joinpath("out_2.txt").unlink()
-
-    session = main(
-        {"paths": tmp_path, "n_workers": 2, "parallel_backend": parallel_backend}
-    )
-
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+    session = build(paths=tmp_path, n_workers=2, parallel_backend=parallel_backend)
     assert session.exit_code == ExitCode.OK
-    assert session.execution_end - session.execution_start < 10
+    assert len(session.tasks) == 2
+    assert tmp_path.joinpath("out_1.txt").exists()
+    assert tmp_path.joinpath("out_2.txt").exists()


 @pytest.mark.end_to_end()
 @pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
-def test_parallel_execution_speedup_w_cli(runner, tmp_path, parallel_backend):
+def test_parallel_execution_w_cli(runner, tmp_path, parallel_backend):
     source = """
     import pytask
-    import time

     @pytask.mark.produces("out_1.txt")
     def task_1(produces):
-        time.sleep(5)
         produces.write_text("1")

     @pytask.mark.produces("out_2.txt")
     def task_2(produces):
-        time.sleep(5)
         produces.write_text("2")
     """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
-
-    start = time()
-    result = runner.invoke(cli, [tmp_path.as_posix()])
-    end = time()
-
-    assert result.exit_code == ExitCode.OK
-    assert end - start > 10
-
-    tmp_path.joinpath("out_1.txt").unlink()
-    tmp_path.joinpath("out_2.txt").unlink()
-
-    start = time()
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
     result = runner.invoke(
         cli,
         [
@@ -96,51 +65,9 @@ def task_2(produces):
             parallel_backend,
         ],
     )
-    end = time()
-
     assert result.exit_code == ExitCode.OK
-    assert "Started 2 workers." in result.output
-    assert end - start < 10
-
-
-@pytest.mark.integration()
-@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
-def test_pytask_execute_task_w_processes(parallel_backend):
-    # Local function which cannot be used with multiprocessing.
-    def myfunc():
-        return 1
-
-    # Verify it cannot be used with multiprocessing because it cannot be pickled.
-    with pytest.raises(AttributeError):
-        pickle.dumps(myfunc)
-
-    task = Task(base_name="task_example", path=Path(), function=myfunc)
-
-    session = Session()
-    session.config = {
-        "n_workers": 2,
-        "parallel_backend": parallel_backend,
-        "show_locals": False,
-        "filterwarnings": [],
-    }
-
-    with PARALLEL_BACKENDS[parallel_backend](
-        max_workers=session.config["n_workers"]
-    ) as executor:
-        session.config["_parallel_executor"] = executor
-
-        backend_name_space = {
-            ParallelBackendChoices.PROCESSES: ProcessesNameSpace,
-            ParallelBackendChoices.THREADS: DefaultBackendNameSpace,
-            ParallelBackendChoices.LOKY: DefaultBackendNameSpace,
-        }[parallel_backend]
-
-        future = backend_name_space.pytask_execute_task(session, task)
-        executor.shutdown()
-
-        warning_reports, exception = future.result()
-        assert warning_reports == []
-        assert exception is None
+    assert tmp_path.joinpath("out_1.txt").exists()
+    assert tmp_path.joinpath("out_2.txt").exists()


 @pytest.mark.end_to_end()
@@ -156,16 +83,15 @@ def task_2(): time.sleep(2); raise NotImplementedError

     @pytask.mark.try_last
     def task_3(): time.sleep(3)
     """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
-
-    session = main(
-        {
-            "paths": tmp_path,
-            "n_workers": 2,
-            "parallel_backend": parallel_backend,
-            "max_failures": 1,
-        }
-    )
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+
+    with restore_sys_path_and_module_after_test_execution():
+        session = build(
+            paths=tmp_path,
+            n_workers=2,
+            parallel_backend=parallel_backend,
+            max_failures=1,
+        )

     assert session.exit_code == ExitCode.FAILED
     assert len(session.tasks) == 3
@@ -201,11 +127,10 @@ def task_4():
     def task_5(): time.sleep(0.1)
     """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

-    session = main(
-        {"paths": tmp_path, "parallel_backend": parallel_backend, "n_workers": 2}
-    )
+    with restore_sys_path_and_module_after_test_execution():
+        session = build(paths=tmp_path, parallel_backend=parallel_backend, n_workers=2)

     assert session.exit_code == ExitCode.OK
     first_task_name = session.execution_reports[0].task.name
@@ -227,7 +152,7 @@ def task_raising_error():
         a = list(range(5))
         raise Exception
     """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))

     args = [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
     if show_locals:
@@ -240,41 +165,11 @@ def task_raising_error():
     assert ("[0, 1, 2, 3, 4]" in result.output) is show_locals


-@pytest.mark.end_to_end()
-@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
-def test_generators_are_removed_from_depends_on_produces(tmp_path, parallel_backend):
-    """Only works with pytask >=0.1.9."""
-    source = """
-    from pathlib import Path
-    import pytask
-
-    @pytask.mark.parametrize("produces", [
-        ((x for x in ["out.txt", "out_2.txt"]),),
-        ["in.txt"],
-    ])
-    def task_example(produces):
-        produces = {0: produces} if isinstance(produces, Path) else produces
-        for p in produces.values():
-            p.write_text("hihi")
-    """
-    tmp_path.joinpath("task_dummy.py").write_text(textwrap.dedent(source))
-
-    session = main(
-        {"paths": tmp_path, "parallel_backend": parallel_backend, "n_workers": 2}
-    )
-
-    assert session.exit_code == ExitCode.OK
-
-
 @pytest.mark.end_to_end()
 @pytest.mark.parametrize(
     "parallel_backend",
     # Capturing warnings is not thread-safe.
-    [
-        backend
-        for backend in PARALLEL_BACKENDS
-        if backend != ParallelBackendChoices.THREADS
-    ],
+    [ParallelBackend.PROCESSES],
 )
 def test_collect_warnings_from_parallelized_tasks(runner, tmp_path, parallel_backend):
     source = """
@@ -283,7 +178,7 @@ def test_collect_warnings_from_parallelized_tasks(runner, tmp_path, parallel_bac
     for i in range(2):

-        @pytask.mark.task(id=i, kwargs={"produces": f"{i}.txt"})
+        @pytask.mark.task(id=str(i), kwargs={"produces": f"{i}.txt"})
         def task_example(produces):
             warnings.warn("This is a warning.")
             produces.touch()
@@ -304,6 +199,7 @@ def task_example(produces):
     assert "task_example.py::task_example[1]" in warnings_block


+@pytest.mark.unit()
 def test_sleeper():
     sleeper = _Sleeper(timings=[1, 2, 3], timing_idx=0)
@@ -323,3 +219,40 @@ def test_sleeper():
     sleeper.sleep()
     end = time()
     assert 1 <= end - start <= 2
+
+
+@pytest.mark.end_to_end()
+@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
+def test_task_that_return(runner, tmp_path, parallel_backend):
+    source = """
+    from pathlib import Path
+    from typing_extensions import Annotated
+
+    def task_example() -> Annotated[str, Path("file.txt")]:
+        return "Hello, Darkness, my old friend."
+    """
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
+    )
+    assert result.exit_code == ExitCode.OK
+    assert tmp_path.joinpath("file.txt").exists()
+
+
+@pytest.mark.end_to_end()
+@pytest.mark.parametrize("parallel_backend", PARALLEL_BACKENDS)
+def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
+    source = """
+    from pathlib import Path
+    from pytask import task
+
+    task_example = task(
+        produces=Path("file.txt")
+    )(lambda *x: "Hello, Darkness, my old friend.")
+    """
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
+    )
+    assert result.exit_code == ExitCode.OK
+    assert tmp_path.joinpath("file.txt").exists()
diff --git a/tox.ini b/tox.ini
index e83a36d..c6a0a69 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,16 +5,20 @@ envlist = pytest
 usedevelop = true

 [testenv:pytest]
+conda_channels =
+    conda-forge/label/pytask_rc
+    conda-forge
+    nodefaults
 conda_deps =
+    setuptools_scm
+    toml
+
     cloudpickle
     loky
-    pytask >=0.3.0
+    pytask >=0.4.0rc2
     pytest
     pytest-cov
     pytest-xdist
-conda_channels =
-    conda-forge
-    nodefaults
 commands =
     pip install --no-deps -e .
     pytest {posargs}
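A closing note on the subtlest piece of this diff, the `register_pickle_by_value` call added to `ProcessesNameSpace.pytask_execute_task`. The sketch below reproduces the situation it fixes; the file name and module contents are hypothetical, but the cloudpickle calls are the ones the diff uses (see cloudpickle#417).

```python
# pytask imports task modules from paths, so they sit in sys.modules without
# being importable by name in a freshly spawned worker. Registering the module
# by value makes cloudpickle embed it in the pickled payload.
import importlib.util
import inspect
import sys
from pathlib import Path

import cloudpickle

Path("task_example.py").write_text("def task_example():\n    return 1\n")

# Simulate pytask's dynamic import of a task module.
spec = importlib.util.spec_from_file_location("task_example", "task_example.py")
module = importlib.util.module_from_spec(spec)
sys.modules["task_example"] = module
spec.loader.exec_module(module)

task_function = module.task_example
task_module = inspect.getmodule(task_function)

# Without this call, cloudpickle pickles the function by reference to a module
# the worker cannot import, and the task fails inside the pool.
cloudpickle.register_pickle_by_value(task_module)
payload = cloudpickle.dumps(task_function)
assert cloudpickle.loads(payload)() == 1
```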