From f97ca9fdd24621770724a589ebb7ddc4013f6a87 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Tue, 19 Dec 2023 22:23:58 +0100 Subject: [PATCH 1/3] Allow task functions to be partialed. --- docs/source/changes.md | 1 + src/_pytask/task_utils.py | 18 +++++++++++++++++- tests/test_task.py | 5 +++-- tests/test_task_utils.py | 24 ++++++++++++++++++++++++ 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/docs/source/changes.md b/docs/source/changes.md index f3fcb050..bb1c74e8 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -22,6 +22,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`525` enables pytask to work with remote files using universal_pathlib. - {pull}`528` improves the codecov setup and coverage. - {pull}`535` reenables and fixes tests with Jupyter. +- {pull}`536` allows partialed functions to be task functions. ## 0.4.4 - 2023-12-04 diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index dd4fd269..5ef1102f 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -1,6 +1,7 @@ """Contains utilities related to the ``@pytask.mark.task`` decorator.""" from __future__ import annotations +import functools import inspect from collections import defaultdict from types import BuiltinFunctionType @@ -115,7 +116,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: path = get_file(unwrapped) parsed_kwargs = {} if kwargs is None else kwargs - parsed_name = name if isinstance(name, str) else func.__name__ + parsed_name = _parse_name(unwrapped, name) parsed_after = _parse_after(after) if hasattr(unwrapped, "pytask_meta"): @@ -148,6 +149,21 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: return wrapper +def _parse_name(func: Callable[..., Any], name: str | None) -> str: + """Parse name from task function.""" + if name: + return name + + if isinstance(func, functools.partial): + func = func.func + + if hasattr(func, "__name__"): + return func.__name__ + + msg = "Cannot infer name for task function." + raise NotImplementedError(msg) + + def _parse_after( after: str | Callable[..., Any] | list[Callable[..., Any]] | None, ) -> str | list[Callable[..., Any]]: diff --git a/tests/test_task.py b/tests/test_task.py index 3f427ce9..b1086198 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -349,8 +349,9 @@ def func(content): result = runner.invoke(cli, [tmp_path.as_posix()]) - assert result.exit_code == ExitCode.COLLECTION_FAILED - assert "1 Collected errors and tasks" in result.output + assert result.exit_code == ExitCode.OK + assert "1 Succeeded" in result.output + assert tmp_path.joinpath("out.txt").read_text() == "hello" @pytest.mark.end_to_end() diff --git a/tests/test_task_utils.py b/tests/test_task_utils.py index 36b2b208..6fcd315c 100644 --- a/tests/test_task_utils.py +++ b/tests/test_task_utils.py @@ -1,10 +1,12 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise # noqa: N813 +from functools import partial from typing import NamedTuple import pytest from _pytask.task_utils import _arg_value_to_id_component +from _pytask.task_utils import _parse_name from _pytask.task_utils import _parse_task_kwargs from attrs import define @@ -56,3 +58,25 @@ def test_parse_task_kwargs(kwargs, expectation, expected): with expectation: result = _parse_task_kwargs(kwargs) assert result == expected + + +def task_func(x): # noqa: ARG001 # pragma: no cover + pass + + +@pytest.mark.unit() +@pytest.mark.parametrize( + ("func", "name", "expectation", "expected"), + [ + (task_func, None, does_not_raise(), "task_func"), + (task_func, "name", does_not_raise(), "name"), + (partial(task_func, x=1), None, does_not_raise(), "task_func"), + (partial(task_func, x=1), "name", does_not_raise(), "name"), + (lambda x: None, None, does_not_raise(), ""), # noqa: ARG005 + (partial(lambda x: None, x=1), None, does_not_raise(), ""), # noqa: ARG005 + ], +) +def test_parse_name(func, name, expectation, expected): + with expectation: + result = _parse_name(func, name) + assert result == expected From bd96374911f4f174a0789f7a2e4b9a088c856dff Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Tue, 19 Dec 2023 22:25:49 +0100 Subject: [PATCH 2/3] Add failing test. --- tests/test_task_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_task_utils.py b/tests/test_task_utils.py index 6fcd315c..666af242 100644 --- a/tests/test_task_utils.py +++ b/tests/test_task_utils.py @@ -74,6 +74,7 @@ def task_func(x): # noqa: ARG001 # pragma: no cover (partial(task_func, x=1), "name", does_not_raise(), "name"), (lambda x: None, None, does_not_raise(), ""), # noqa: ARG005 (partial(lambda x: None, x=1), None, does_not_raise(), ""), # noqa: ARG005 + (1, None, pytest.raises(NotImplementedError, match="Cannot"), None), ], ) def test_parse_name(func, name, expectation, expected): From 63376334fc9b911acb398c7f5560b6fea5ae3cd4 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Tue, 19 Dec 2023 22:52:12 +0100 Subject: [PATCH 3/3] fix. --- tox.ini | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tox.ini b/tox.ini index 8adee493..fb088034 100644 --- a/tox.ini +++ b/tox.ini @@ -4,17 +4,15 @@ envlist = docs, test [testenv] passenv = CI -package = wheel +package = editable [testenv:test] extras = all, test deps = pygraphviz;platform_system != "Windows" - commands = pytest --nbmake {posargs} - [testenv:docs] extras = docs, test commands =