diff --git a/docs/source/changes.md b/docs/source/changes.md index a950e4d7..11401aa2 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -16,6 +16,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and passed via the command line are relative to CWD and paths in the configuration relative to the config file. - {pull}`555` uses new-style hook wrappers and requires pluggy 1.3 for typing. +- {pull}`557` fixes an issue with `@task(after=...)` in notebooks and terminals. ## 0.4.5 - 2024-01-09 diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 4609d3ee..c5e00521 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -171,17 +171,16 @@ def _parse_after( if isinstance(after, str): return after if callable(after): - if not hasattr(after, "pytask_meta"): - after.pytask_meta = CollectionMetadata() # type: ignore[attr-defined] - return [after.pytask_meta._id] # type: ignore[attr-defined] + after = [after] if isinstance(after, list): new_after = [] for func in after: if not hasattr(func, "pytask_meta"): - func.pytask_meta = CollectionMetadata() # type: ignore[attr-defined] - new_after.append(func.pytask_meta._id) # type: ignore[attr-defined] + func = task()(func) # noqa: PLW2901 + new_after.append(func.pytask_meta._id) + return new_after msg = ( - "'after' should be an expression string, a task, or a list of class. Got " + "'after' should be an expression string, a task, or a list of tasks. Got " f"{after}, instead." ) raise TypeError(msg) diff --git a/tests/test_task.py b/tests/test_task.py index fdaeee5c..57365896 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,5 +1,6 @@ from __future__ import annotations +import subprocess import textwrap import pytest @@ -637,12 +638,17 @@ def task_first() -> Annotated[str, Path("out.txt")]: assert "1 Skipped because unchanged" in result.output -def test_task_will_be_executed_after_another_one_with_function(tmp_path): - source = """ +@pytest.mark.end_to_end() +@pytest.mark.parametrize("decorator", ["", "@task"]) +def test_task_will_be_executed_after_another_one_with_function( + runner, tmp_path, decorator +): + source = f""" from pytask import task from pathlib import Path from typing_extensions import Annotated + {decorator} def task_first() -> Annotated[str, Path("out.txt")]: return "Hello, World!" @@ -652,8 +658,35 @@ def task_second(): """ tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) - session = build(paths=tmp_path) + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + + +@pytest.mark.end_to_end() +@pytest.mark.parametrize("decorator", ["", "@task"]) +def test_task_will_be_executed_after_another_one_with_function_session( + tmp_path, decorator +): + source = f""" + from pytask import task, ExitCode, build + from pathlib import Path + from typing_extensions import Annotated + + {decorator} + def task_first() -> Annotated[str, Path("out.txt")]: + return "Hello, World!" + + @task(after=task_first) + def task_second(): + assert Path(__file__).parent.joinpath("out.txt").exists() + + session = build(tasks=[task_first, task_second]) assert session.exit_code == ExitCode.OK + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + result = subprocess.run(("pytask",), cwd=tmp_path, capture_output=True, check=False) + assert result.returncode == ExitCode.OK def test_raise_error_for_wrong_after_expression(runner, tmp_path):