Skip to content

Commit dd2ab42

Browse files
authored
Stop unwrapping coiled functions. (#595)
1 parent d313085 commit dd2ab42

File tree

9 files changed

+133
-40
lines changed

9 files changed

+133
-40
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@ src/_pytask/_version.py
2222
*.pkl
2323

2424
tests/test_jupyter/*.txt
25+
.mypy_cache
26+
.pytest_cache
27+
.ruff_cache
2528
.venv

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
3939
- {pull}`593` recreate `PythonNode`s every run since they carry the `_NoDefault` enum as
4040
the value whose state is `None`.
4141
- {pull}`594` publishes `NodeLoadError`.
42+
- {pull}`595` stops unwrapping task functions until a `coiled.function.Function`.
4243
- {pull}`596` add project management with rye.
4344
- {pull}`598` replaces requests with httpx.
4445

pyproject.toml

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ classifiers = [
2121
]
2222
dynamic = ["version"]
2323
dependencies = [
24-
"attrs>=21.3",
25-
"click",
26-
"click-default-group",
27-
"networkx>=2.4",
28-
"optree>=0.9",
29-
"packaging",
30-
"pluggy>=1.3.0",
31-
"rich",
32-
"sqlalchemy>=2",
33-
'tomli>=1; python_version < "3.11"',
34-
'typing-extensions; python_version < "3.9"',
35-
"universal-pathlib>=0.2.2",
24+
"attrs>=21.3",
25+
"click",
26+
"click-default-group",
27+
"networkx>=2.4",
28+
"optree>=0.9",
29+
"packaging",
30+
"pluggy>=1.3.0",
31+
"rich",
32+
"sqlalchemy>=2",
33+
'tomli>=1; python_version < "3.11"',
34+
'typing-extensions; python_version < "3.9"',
35+
"universal-pathlib>=0.2.2",
3636
]
3737

3838
[project.readme]
@@ -61,14 +61,16 @@ docs = [
6161
"sphinxext-opengraph",
6262
]
6363
test = [
64-
"deepdiff",
65-
"nbmake",
66-
"pexpect",
67-
"pytest",
68-
"pytest-cov",
69-
"pytest-xdist",
70-
"syrupy",
71-
"aiohttp",
64+
"deepdiff",
65+
"nbmake",
66+
"pexpect",
67+
"pytest",
68+
"pytest-cov",
69+
"pytest-xdist",
70+
"syrupy",
71+
# For HTTPPath tests.
72+
"aiohttp",
73+
"coiled",
7274
]
7375

7476
[project.urls]

src/_pytask/coiled_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
from typing import Callable
5+
6+
from attrs import define
7+
8+
try:
9+
from coiled.function import Function
10+
except ImportError:
11+
12+
@define
13+
class Function: # type: ignore[no-redef]
14+
cluster_kwargs: dict[str, Any]
15+
environ: dict[str, Any]
16+
function: Callable[..., Any] | None
17+
keepalive: str | None
18+
19+
20+
__all__ = ["Function"]
21+
22+
23+
def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]:
24+
"""Extract the kwargs for a coiled function."""
25+
return {
26+
"cluster_kwargs": func._cluster_kwargs,
27+
"keepalive": func.keepalive,
28+
"environ": func._environ,
29+
"local": func._local,
30+
"name": func._name,
31+
}

src/_pytask/collect.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from rich.text import Text
1717
from upath import UPath
1818

19+
from _pytask.coiled_utils import Function
20+
from _pytask.coiled_utils import extract_coiled_function_kwargs
1921
from _pytask.collect_utils import create_name_of_python_node
2022
from _pytask.collect_utils import parse_dependencies_from_task_function
2123
from _pytask.collect_utils import parse_products_from_task_function
@@ -46,6 +48,7 @@
4648
from _pytask.reports import CollectionReport
4749
from _pytask.shared import find_duplicates
4850
from _pytask.shared import to_list
51+
from _pytask.shared import unwrap_task_function
4952
from _pytask.task_utils import COLLECTED_TASKS
5053
from _pytask.task_utils import task as task_decorator
5154
from _pytask.typing import is_task_function
@@ -311,15 +314,21 @@ def pytask_collect_task(
311314
)
312315

313316
markers = get_all_marks(obj)
314-
collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None
315-
after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else []
316-
is_generator = (
317-
obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False
318-
)
319317

320-
# Get the underlying function to avoid having different states of the function,
321-
# e.g. due to pytask_meta, in different layers of the wrapping.
322-
unwrapped = inspect.unwrap(obj)
318+
if hasattr(obj, "pytask_meta"):
319+
attributes = {
320+
**obj.pytask_meta.attributes,
321+
"collection_id": obj.pytask_meta._id,
322+
"after": obj.pytask_meta.after,
323+
"is_generator": obj.pytask_meta.is_generator,
324+
}
325+
else:
326+
attributes = {"collection_id": None, "after": [], "is_generator": False}
327+
328+
unwrapped = unwrap_task_function(obj)
329+
if isinstance(unwrapped, Function):
330+
attributes["coiled_kwargs"] = extract_coiled_function_kwargs(unwrapped)
331+
unwrapped = unwrap_task_function(unwrapped.function)
323332

324333
if path is None:
325334
return TaskWithoutPath(
@@ -328,11 +337,7 @@ def pytask_collect_task(
328337
depends_on=dependencies,
329338
produces=products,
330339
markers=markers,
331-
attributes={
332-
"collection_id": collection_id,
333-
"after": after,
334-
"is_generator": is_generator,
335-
},
340+
attributes=attributes,
336341
)
337342
return Task(
338343
base_name=name,
@@ -341,11 +346,7 @@ def pytask_collect_task(
341346
depends_on=dependencies,
342347
produces=products,
343348
markers=markers,
344-
attributes={
345-
"collection_id": collection_id,
346-
"after": after,
347-
"is_generator": is_generator,
348-
},
349+
attributes=attributes,
349350
)
350351
if isinstance(obj, PTask):
351352
return obj

src/_pytask/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class CollectionMetadata:
5050
"""
5151

5252
after: str | list[Callable[..., Any]] = field(factory=list)
53+
attributes: dict[str, Any] = field(factory=dict)
5354
is_generator: bool = False
5455
id_: str | None = None
5556
kwargs: dict[str, Any] = field(factory=dict)

src/_pytask/shared.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
from __future__ import annotations
44

55
import glob
6+
import inspect
67
from pathlib import Path
78
from typing import TYPE_CHECKING
89
from typing import Any
10+
from typing import Callable
911
from typing import Iterable
1012
from typing import Sequence
1113

1214
import click
1315

16+
from _pytask.coiled_utils import Function
1417
from _pytask.console import format_node_name
1518
from _pytask.console import format_task_name
1619
from _pytask.node_protocols import PNode
@@ -30,6 +33,7 @@
3033
"parse_paths",
3134
"reduce_names_of_multiple_nodes",
3235
"to_list",
36+
"unwrap_task_function",
3337
]
3438

3539

@@ -146,3 +150,13 @@ def convert_to_enum(value: Any, enum: type[Enum]) -> Enum:
146150
values = [e.value for e in enum]
147151
msg = f"Value {value!r} is not a valid {enum!r}. Valid values are {values}."
148152
raise ValueError(msg) from None
153+
154+
155+
def unwrap_task_function(obj: Any) -> Callable[..., Any]:
156+
"""Unwrap a task function.
157+
158+
Get the underlying function to avoid having different states of the function, e.g.
159+
due to pytask_meta, in different layers of the wrapping.
160+
161+
"""
162+
return inspect.unwrap(obj, stop=lambda x: isinstance(x, Function))

src/_pytask/task_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@
1212

1313
import attrs
1414

15+
from _pytask.coiled_utils import Function
16+
from _pytask.coiled_utils import extract_coiled_function_kwargs
1517
from _pytask.console import get_file
1618
from _pytask.mark import Mark
1719
from _pytask.models import CollectionMetadata
1820
from _pytask.shared import find_duplicates
21+
from _pytask.shared import unwrap_task_function
1922
from _pytask.typing import is_task_function
2023

2124
if TYPE_CHECKING:
@@ -117,7 +120,12 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
117120
)
118121
raise ValueError(msg)
119122

120-
unwrapped = inspect.unwrap(func)
123+
unwrapped = unwrap_task_function(func)
124+
if isinstance(unwrapped, Function):
125+
coiled_kwargs = extract_coiled_function_kwargs(unwrapped)
126+
unwrapped = unwrap_task_function(unwrapped.function)
127+
else:
128+
coiled_kwargs = None
121129

122130
# We do not allow builtins as functions because we would need to use
123131
# ``inspect.stack`` to infer their caller location and they are unable to carry
@@ -145,7 +153,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
145153
unwrapped.pytask_meta.produces = produces
146154
unwrapped.pytask_meta.after = parsed_after
147155
else:
148-
unwrapped.pytask_meta = CollectionMetadata(
156+
unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined]
149157
after=parsed_after,
150158
is_generator=is_generator,
151159
id_=id,
@@ -155,6 +163,9 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
155163
produces=produces,
156164
)
157165

166+
if coiled_kwargs and hasattr(unwrapped, "pytask_meta"):
167+
unwrapped.pytask_meta.attributes["coiled_kwargs"] = coiled_kwargs
168+
158169
# Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage
159170
# collection when the function definition is overwritten in a loop.
160171
COLLECTED_TASKS[path].append(unwrapped)

tests/test_shared.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import functools
34
import textwrap
45
from contextlib import ExitStack as does_not_raise # noqa: N813
56

7+
import coiled
68
import pytest
79
from _pytask.shared import convert_to_enum
810
from _pytask.shared import find_duplicates
11+
from _pytask.shared import unwrap_task_function
912
from pytask import ExitCode
1013
from pytask import ShowCapture
1114
from pytask import build
@@ -49,3 +52,29 @@ def test_convert_to_enum(value, enum, expectation, expected):
4952
with expectation:
5053
result = convert_to_enum(value, enum)
5154
assert result == expected
55+
56+
57+
@pytest.mark.unit()
58+
def test_unwrap_task_function():
59+
def task():
60+
pass
61+
62+
# partialed functions are only unwrapped after wraps.
63+
partialed = functools.wraps(task)(functools.partial(task))
64+
assert unwrap_task_function(partialed) is task
65+
66+
partialed = functools.partial(task)
67+
assert unwrap_task_function(partialed) is partialed
68+
69+
def decorator(func):
70+
@functools.wraps(func)
71+
def wrapper():
72+
return func()
73+
74+
return wrapper
75+
76+
decorated = decorator(task)
77+
assert unwrap_task_function(decorated) is task
78+
79+
coiled_decorated = coiled.function()(task)
80+
assert unwrap_task_function(coiled_decorated) is coiled_decorated

0 commit comments

Comments
 (0)