From 4f9424f97a59d8c8e386d005040d0e191f38288d Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 12 Apr 2024 19:38:39 +0200 Subject: [PATCH 1/5] Stop unwrapping coiled functions. --- docs/source/changes.md | 1 + src/_pytask/collect.py | 5 ++--- src/_pytask/shared.py | 26 ++++++++++++++++++++++++++ src/_pytask/task_utils.py | 5 +++-- tests/test_shared.py | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 64 insertions(+), 5 deletions(-) diff --git a/docs/source/changes.md b/docs/source/changes.md index 34b8abae..92ce12e5 100644 --- a/docs/source/changes.md +++ b/docs/source/changes.md @@ -39,6 +39,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and - {pull}`593` recreate `PythonNode`s every run since they carry the `_NoDefault` enum as the value whose state is `None`. - {pull}`594` publishes `NodeLoadError`. +- {pull}`595` stops unwrapping task functions until a `coiled.function.Function`. ## 0.4.7 - 2024-03-19 diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index 1b1f9a2e..e118e9f7 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -46,6 +46,7 @@ from _pytask.reports import CollectionReport from _pytask.shared import find_duplicates from _pytask.shared import to_list +from _pytask.shared import unwrap_task_function from _pytask.task_utils import COLLECTED_TASKS from _pytask.task_utils import task as task_decorator from _pytask.typing import is_task_function @@ -317,9 +318,7 @@ def pytask_collect_task( obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False ) - # Get the underlying function to avoid having different states of the function, - # e.g. due to pytask_meta, in different layers of the wrapping. - unwrapped = inspect.unwrap(obj) + unwrapped = unwrap_task_function(obj) if path is None: return TaskWithoutPath( diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 2174b2a7..ac3e43f8 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -3,13 +3,16 @@ from __future__ import annotations import glob +import inspect from pathlib import Path from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Sequence import click +from attrs import define from _pytask.console import format_node_name from _pytask.console import format_task_name @@ -23,6 +26,18 @@ import networkx as nx +try: + from coiled.function import Function as CoiledFunction +except ImportError: + + @define + class CoiledFunction: # type: ignore[no-redef] + cluster_kwargs: dict[str, Any] + environ: dict[str, Any] + function: Callable[..., Any] | None + keepalive: str | None + + __all__ = [ "convert_to_enum", "find_duplicates", @@ -30,6 +45,7 @@ "parse_paths", "reduce_names_of_multiple_nodes", "to_list", + "unwrap_task_function", ] @@ -146,3 +162,13 @@ def convert_to_enum(value: Any, enum: type[Enum]) -> Enum: values = [e.value for e in enum] msg = f"Value {value!r} is not a valid {enum!r}. Valid values are {values}." raise ValueError(msg) from None + + +def unwrap_task_function(obj: Any) -> Callable[..., Any]: + """Unwrap a task function. + + Get the underlying function to avoid having different states of the function, e.g. + due to pytask_meta, in different layers of the wrapping. + + """ + return inspect.unwrap(obj, stop=lambda x: isinstance(x, CoiledFunction)) diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index a9f556c4..3984286a 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -16,6 +16,7 @@ from _pytask.mark import Mark from _pytask.models import CollectionMetadata from _pytask.shared import find_duplicates +from _pytask.shared import unwrap_task_function from _pytask.typing import is_task_function if TYPE_CHECKING: @@ -117,7 +118,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: ) raise ValueError(msg) - unwrapped = inspect.unwrap(func) + unwrapped = unwrap_task_function(func) # We do not allow builtins as functions because we would need to use # ``inspect.stack`` to infer their caller location and they are unable to carry @@ -145,7 +146,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: unwrapped.pytask_meta.produces = produces unwrapped.pytask_meta.after = parsed_after else: - unwrapped.pytask_meta = CollectionMetadata( + unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined] after=parsed_after, is_generator=is_generator, id_=id, diff --git a/tests/test_shared.py b/tests/test_shared.py index 6408950a..545abb5a 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -1,11 +1,13 @@ from __future__ import annotations +import functools import textwrap from contextlib import ExitStack as does_not_raise # noqa: N813 import pytest from _pytask.shared import convert_to_enum from _pytask.shared import find_duplicates +from _pytask.shared import unwrap_task_function from pytask import ExitCode from pytask import ShowCapture from pytask import build @@ -49,3 +51,33 @@ def test_convert_to_enum(value, enum, expectation, expected): with expectation: result = convert_to_enum(value, enum) assert result == expected + + +@pytest.mark.unit() +def test_unwrap_task_function(): + def task(): + pass + + # partialed functions are only unwrapped after wraps. + partialed = functools.wraps(task)(functools.partial(task)) + assert unwrap_task_function(partialed) is task + + partialed = functools.partial(task) + assert unwrap_task_function(partialed) is partialed + + def decorator(func): + @functools.wraps(func) + def wrapper(): + return func() + + return wrapper + + decorated = decorator(task) + assert unwrap_task_function(decorated) is task + + from _pytask.shared import CoiledFunction + + coiled_function = functools.wraps(task)( + CoiledFunction(function=task, cluster_kwargs={}, environ={}, keepalive=None) + ) + assert unwrap_task_function(coiled_function) is coiled_function From 89b77441fb3b4c2ae29b4efa891a185c11c6b2c1 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 12 Apr 2024 19:50:59 +0200 Subject: [PATCH 2/5] fix. --- src/_pytask/shared.py | 1 + tests/test_shared.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index ac3e43f8..099054e3 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -38,6 +38,7 @@ class CoiledFunction: # type: ignore[no-redef] keepalive: str | None + __all__ = [ "convert_to_enum", "find_duplicates", diff --git a/tests/test_shared.py b/tests/test_shared.py index 545abb5a..5cfbe348 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -74,10 +74,3 @@ def wrapper(): decorated = decorator(task) assert unwrap_task_function(decorated) is task - - from _pytask.shared import CoiledFunction - - coiled_function = functools.wraps(task)( - CoiledFunction(function=task, cluster_kwargs={}, environ={}, keepalive=None) - ) - assert unwrap_task_function(coiled_function) is coiled_function From 512e0a77d351999e043a0b281e953bd059e91a72 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 12 Apr 2024 19:51:25 +0200 Subject: [PATCH 3/5] fix. --- src/_pytask/shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 099054e3..ac3e43f8 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -38,7 +38,6 @@ class CoiledFunction: # type: ignore[no-redef] keepalive: str | None - __all__ = [ "convert_to_enum", "find_duplicates", From a4e2a72bce971da965a4f1da5c1b8e8c85ebf7b5 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 13 Apr 2024 01:52:40 +0200 Subject: [PATCH 4/5] Fix. --- .gitignore | 3 +++ src/_pytask/coiled_utils.py | 31 +++++++++++++++++++++++++++++++ src/_pytask/collect.py | 32 +++++++++++++++++--------------- src/_pytask/models.py | 1 + src/_pytask/shared.py | 16 ++-------------- src/_pytask/task_utils.py | 10 ++++++++++ 6 files changed, 64 insertions(+), 29 deletions(-) create mode 100644 src/_pytask/coiled_utils.py diff --git a/.gitignore b/.gitignore index d5e3ea27..50159a20 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ src/_pytask/_version.py *.pkl tests/test_jupyter/*.txt +.mypy_cache +.pytest_cache +.ruff_cache diff --git a/src/_pytask/coiled_utils.py b/src/_pytask/coiled_utils.py new file mode 100644 index 00000000..41614681 --- /dev/null +++ b/src/_pytask/coiled_utils.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Any +from typing import Callable + +from attrs import define + +try: + from coiled.function import Function +except ImportError: + + @define + class Function: # type: ignore[no-redef] + cluster_kwargs: dict[str, Any] + environ: dict[str, Any] + function: Callable[..., Any] | None + keepalive: str | None + + +__all__ = ["Function"] + + +def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]: + """Extract the kwargs for a coiled function.""" + return { + "cluster_kwargs": func._cluster_kwargs, + "keepalive": func.keepalive, + "environ": func._environ, + "local": func._local, + "name": func._name, + } diff --git a/src/_pytask/collect.py b/src/_pytask/collect.py index e118e9f7..055a80d5 100644 --- a/src/_pytask/collect.py +++ b/src/_pytask/collect.py @@ -16,6 +16,8 @@ from rich.text import Text from upath import UPath +from _pytask.coiled_utils import Function +from _pytask.coiled_utils import extract_coiled_function_kwargs from _pytask.collect_utils import create_name_of_python_node from _pytask.collect_utils import parse_dependencies_from_task_function from _pytask.collect_utils import parse_products_from_task_function @@ -312,13 +314,21 @@ def pytask_collect_task( ) markers = get_all_marks(obj) - collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None - after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else [] - is_generator = ( - obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False - ) + + if hasattr(obj, "pytask_meta"): + attributes = { + **obj.pytask_meta.attributes, + "collection_id": obj.pytask_meta._id, + "after": obj.pytask_meta.after, + "is_generator": obj.pytask_meta.is_generator, + } + else: + attributes = {"collection_id": None, "after": [], "is_generator": False} unwrapped = unwrap_task_function(obj) + if isinstance(unwrapped, Function): + attributes["coiled_kwargs"] = extract_coiled_function_kwargs(unwrapped) + unwrapped = unwrap_task_function(unwrapped.function) if path is None: return TaskWithoutPath( @@ -327,11 +337,7 @@ def pytask_collect_task( depends_on=dependencies, produces=products, markers=markers, - attributes={ - "collection_id": collection_id, - "after": after, - "is_generator": is_generator, - }, + attributes=attributes, ) return Task( base_name=name, @@ -340,11 +346,7 @@ def pytask_collect_task( depends_on=dependencies, produces=products, markers=markers, - attributes={ - "collection_id": collection_id, - "after": after, - "is_generator": is_generator, - }, + attributes=attributes, ) if isinstance(obj, PTask): return obj diff --git a/src/_pytask/models.py b/src/_pytask/models.py index afd28eda..791dbb06 100644 --- a/src/_pytask/models.py +++ b/src/_pytask/models.py @@ -50,6 +50,7 @@ class CollectionMetadata: """ after: str | list[Callable[..., Any]] = field(factory=list) + attributes: dict[str, Any] = field(factory=dict) is_generator: bool = False id_: str | None = None kwargs: dict[str, Any] = field(factory=dict) diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index ac3e43f8..b7a2f491 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -12,8 +12,8 @@ from typing import Sequence import click -from attrs import define +from _pytask.coiled_utils import Function from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.node_protocols import PNode @@ -26,18 +26,6 @@ import networkx as nx -try: - from coiled.function import Function as CoiledFunction -except ImportError: - - @define - class CoiledFunction: # type: ignore[no-redef] - cluster_kwargs: dict[str, Any] - environ: dict[str, Any] - function: Callable[..., Any] | None - keepalive: str | None - - __all__ = [ "convert_to_enum", "find_duplicates", @@ -171,4 +159,4 @@ def unwrap_task_function(obj: Any) -> Callable[..., Any]: due to pytask_meta, in different layers of the wrapping. """ - return inspect.unwrap(obj, stop=lambda x: isinstance(x, CoiledFunction)) + return inspect.unwrap(obj, stop=lambda x: isinstance(x, Function)) diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 3984286a..7ee00178 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -12,6 +12,8 @@ import attrs +from _pytask.coiled_utils import Function +from _pytask.coiled_utils import extract_coiled_function_kwargs from _pytask.console import get_file from _pytask.mark import Mark from _pytask.models import CollectionMetadata @@ -119,6 +121,11 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: raise ValueError(msg) unwrapped = unwrap_task_function(func) + if isinstance(unwrapped, Function): + coiled_kwargs = extract_coiled_function_kwargs(unwrapped) + unwrapped = unwrap_task_function(unwrapped.function) + else: + coiled_kwargs = None # We do not allow builtins as functions because we would need to use # ``inspect.stack`` to infer their caller location and they are unable to carry @@ -156,6 +163,9 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: produces=produces, ) + if coiled_kwargs and hasattr(unwrapped, "pytask_meta"): + unwrapped.pytask_meta.attributes["coiled_kwargs"] = coiled_kwargs + # Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage # collection when the function definition is overwritten in a loop. COLLECTED_TASKS[path].append(unwrapped) From d31c201a15ede0b37cd7f908dbb105d8491d7320 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 18 Apr 2024 23:07:23 +0200 Subject: [PATCH 5/5] Add test.! --- pyproject.toml | 45 ++++++++++++++++++++++---------------------- tests/test_shared.py | 4 ++++ 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5fcd58d4..96e57f33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,18 +21,18 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ - "attrs>=21.3", - "click", - "click-default-group", - "networkx>=2.4", - "optree>=0.9", - "packaging", - "pluggy>=1.3.0", - "rich", - "sqlalchemy>=2", - 'tomli>=1; python_version < "3.11"', - 'typing-extensions; python_version < "3.9"', - "universal-pathlib>=0.2.2", + "attrs>=21.3", + "click", + "click-default-group", + "networkx>=2.4", + "optree>=0.9", + "packaging", + "pluggy>=1.3.0", + "rich", + "sqlalchemy>=2", + 'tomli>=1; python_version < "3.11"', + 'typing-extensions; python_version < "3.9"', + "universal-pathlib>=0.2.2", ] [project.readme] @@ -61,16 +61,17 @@ docs = [ "sphinxext-opengraph", ] test = [ - "deepdiff", - "nbmake", - "pexpect", - "pytest", - "pytest-cov", - "pytest-xdist", - "syrupy", - # For HTTPPath tests. - "aiohttp", - "requests", + "deepdiff", + "nbmake", + "pexpect", + "pytest", + "pytest-cov", + "pytest-xdist", + "syrupy", + # For HTTPPath tests. + "aiohttp", + "requests", + "coiled", ] [project.urls] diff --git a/tests/test_shared.py b/tests/test_shared.py index 5cfbe348..882daba5 100644 --- a/tests/test_shared.py +++ b/tests/test_shared.py @@ -4,6 +4,7 @@ import textwrap from contextlib import ExitStack as does_not_raise # noqa: N813 +import coiled import pytest from _pytask.shared import convert_to_enum from _pytask.shared import find_duplicates @@ -74,3 +75,6 @@ def wrapper(): decorated = decorator(task) assert unwrap_task_function(decorated) is task + + coiled_decorated = coiled.function()(task) + assert unwrap_task_function(coiled_decorated) is coiled_decorated