diff --git a/docs/source/changes.md b/docs/source/changes.md index e3b5a07e..75756f90 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -22,6 +22,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and remains the same, the consecutive tasks are not executed. It remained from when pytask relied on timestamps. - {pull}`497` removes unnecessary code in the collection of tasks. +- {pull}`498` fixes an error when using {class}`~pytask.Task` and + {class}`~pytask.TaskWithoutPath` in task modules. ## 0.4.2 - 2023-11-08 diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 68170283..2f3f2762 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -275,7 +275,7 @@ def pytask_collect_task( produces=products, markers=markers, ) - if isinstance(obj, PTask): + if isinstance(obj, PTask) and not inspect.isclass(obj): return obj return None diff --git a/src/_pytask/mark_utils.py b/src/_pytask/mark_utils.py index cb6ba4e5..03097e40 100644 --- a/src/_pytask/mark_utils.py +++ b/src/_pytask/mark_utils.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +import inspect from typing import Any from typing import TYPE_CHECKING @@ -18,7 +19,7 @@ def get_all_marks(obj_or_task: Any | PTask) -> list[Mark]: """Get all marks from a callable or task.""" - if isinstance(obj_or_task, PTask): + if isinstance(obj_or_task, PTask) and not inspect.isclass(obj_or_task): marks = obj_or_task.markers else: obj = obj_or_task diff --git a/tests/test_collect.py b/tests/test_collect.py index 78902669..a3c976c5 100644 --- a/tests/test_collect.py +++ b/tests/test_collect.py @@ -679,3 +679,14 @@ def task_mixed(): pass assert result.exit_code == ExitCode.COLLECTION_FAILED assert "Could not collect" in result.output assert "The task cannot have" in result.output + + +@pytest.mark.end_to_end() +def test_module_can_be_collected(runner, tmp_path): + source = """ + from pytask import Task, TaskWithoutPath + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK diff --git a/tests/test_execute.py b/tests/test_execute.py index c1107981..a79dc916 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -1008,3 +1008,43 @@ def func(path): session = build(tasks=[task]) assert session.exit_code == ExitCode.OK assert tmp_path.joinpath("out.txt").exists() + + +def test_collect_task(runner, tmp_path): + source = """ + from pytask import Task, PathNode + from pathlib import Path + + def func(path): path.touch() + + task_create_file = Task( + base_name="task", + function=func, + path=Path(__file__), + produces={"path": PathNode(path=Path(__file__).parent / "out.txt")}, + ) + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").exists() + + +def test_collect_task_without_path(runner, tmp_path): + source = """ + from pytask import TaskWithoutPath, PathNode + from pathlib import Path + + def func(path): path.touch() + + task_create_file = TaskWithoutPath( + name="task", + function=func, + produces={"path": PathNode(path=Path(__file__).parent / "out.txt")}, + ) + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + assert result.exit_code == ExitCode.OK + assert tmp_path.joinpath("out.txt").exists()