Skip to content

Stop unwrapping coiled functions. #595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ src/_pytask/_version.py
*.pkl

tests/test_jupyter/*.txt
.mypy_cache
.pytest_cache
.ruff_cache
.venv
1 change: 1 addition & 0 deletions docs/source/changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
- {pull}`593` recreate `PythonNode`s every run since they carry the `_NoDefault` enum as
the value whose state is `None`.
- {pull}`594` publishes `NodeLoadError`.
- {pull}`595` stops unwrapping task functions until a `coiled.function.Function`.
- {pull}`596` add project management with rye.
- {pull}`598` replaces requests with httpx.

Expand Down
42 changes: 22 additions & 20 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"attrs>=21.3",
"click",
"click-default-group",
"networkx>=2.4",
"optree>=0.9",
"packaging",
"pluggy>=1.3.0",
"rich",
"sqlalchemy>=2",
'tomli>=1; python_version < "3.11"',
'typing-extensions; python_version < "3.9"',
"universal-pathlib>=0.2.2",
"attrs>=21.3",
"click",
"click-default-group",
"networkx>=2.4",
"optree>=0.9",
"packaging",
"pluggy>=1.3.0",
"rich",
"sqlalchemy>=2",
'tomli>=1; python_version < "3.11"',
'typing-extensions; python_version < "3.9"',
"universal-pathlib>=0.2.2",
]

[project.readme]
Expand Down Expand Up @@ -61,14 +61,16 @@ docs = [
"sphinxext-opengraph",
]
test = [
"deepdiff",
"nbmake",
"pexpect",
"pytest",
"pytest-cov",
"pytest-xdist",
"syrupy",
"aiohttp",
"deepdiff",
"nbmake",
"pexpect",
"pytest",
"pytest-cov",
"pytest-xdist",
"syrupy",
# For HTTPPath tests.
"aiohttp",
"coiled",
]

[project.urls]
Expand Down
31 changes: 31 additions & 0 deletions src/_pytask/coiled_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Any
from typing import Callable

from attrs import define

try:
from coiled.function import Function
except ImportError:

@define
class Function: # type: ignore[no-redef]
cluster_kwargs: dict[str, Any]
environ: dict[str, Any]
function: Callable[..., Any] | None
keepalive: str | None


__all__ = ["Function"]


def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]:
"""Extract the kwargs for a coiled function."""
return {
"cluster_kwargs": func._cluster_kwargs,
"keepalive": func.keepalive,
"environ": func._environ,
"local": func._local,
"name": func._name,
}
37 changes: 19 additions & 18 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from rich.text import Text
from upath import UPath

from _pytask.coiled_utils import Function
from _pytask.coiled_utils import extract_coiled_function_kwargs
from _pytask.collect_utils import create_name_of_python_node
from _pytask.collect_utils import parse_dependencies_from_task_function
from _pytask.collect_utils import parse_products_from_task_function
Expand Down Expand Up @@ -46,6 +48,7 @@
from _pytask.reports import CollectionReport
from _pytask.shared import find_duplicates
from _pytask.shared import to_list
from _pytask.shared import unwrap_task_function
from _pytask.task_utils import COLLECTED_TASKS
from _pytask.task_utils import task as task_decorator
from _pytask.typing import is_task_function
Expand Down Expand Up @@ -311,15 +314,21 @@ def pytask_collect_task(
)

markers = get_all_marks(obj)
collection_id = obj.pytask_meta._id if hasattr(obj, "pytask_meta") else None
after = obj.pytask_meta.after if hasattr(obj, "pytask_meta") else []
is_generator = (
obj.pytask_meta.is_generator if hasattr(obj, "pytask_meta") else False
)

# Get the underlying function to avoid having different states of the function,
# e.g. due to pytask_meta, in different layers of the wrapping.
unwrapped = inspect.unwrap(obj)
if hasattr(obj, "pytask_meta"):
attributes = {
**obj.pytask_meta.attributes,
"collection_id": obj.pytask_meta._id,
"after": obj.pytask_meta.after,
"is_generator": obj.pytask_meta.is_generator,
}
else:
attributes = {"collection_id": None, "after": [], "is_generator": False}

unwrapped = unwrap_task_function(obj)
if isinstance(unwrapped, Function):
attributes["coiled_kwargs"] = extract_coiled_function_kwargs(unwrapped)
unwrapped = unwrap_task_function(unwrapped.function)

if path is None:
return TaskWithoutPath(
Expand All @@ -328,11 +337,7 @@ def pytask_collect_task(
depends_on=dependencies,
produces=products,
markers=markers,
attributes={
"collection_id": collection_id,
"after": after,
"is_generator": is_generator,
},
attributes=attributes,
)
return Task(
base_name=name,
Expand All @@ -341,11 +346,7 @@ def pytask_collect_task(
depends_on=dependencies,
produces=products,
markers=markers,
attributes={
"collection_id": collection_id,
"after": after,
"is_generator": is_generator,
},
attributes=attributes,
)
if isinstance(obj, PTask):
return obj
Expand Down
1 change: 1 addition & 0 deletions src/_pytask/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class CollectionMetadata:
"""

after: str | list[Callable[..., Any]] = field(factory=list)
attributes: dict[str, Any] = field(factory=dict)
is_generator: bool = False
id_: str | None = None
kwargs: dict[str, Any] = field(factory=dict)
Expand Down
14 changes: 14 additions & 0 deletions src/_pytask/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from __future__ import annotations

import glob
import inspect
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Sequence

import click

from _pytask.coiled_utils import Function
from _pytask.console import format_node_name
from _pytask.console import format_task_name
from _pytask.node_protocols import PNode
Expand All @@ -30,6 +33,7 @@
"parse_paths",
"reduce_names_of_multiple_nodes",
"to_list",
"unwrap_task_function",
]


Expand Down Expand Up @@ -146,3 +150,13 @@ def convert_to_enum(value: Any, enum: type[Enum]) -> Enum:
values = [e.value for e in enum]
msg = f"Value {value!r} is not a valid {enum!r}. Valid values are {values}."
raise ValueError(msg) from None


def unwrap_task_function(obj: Any) -> Callable[..., Any]:
"""Unwrap a task function.

Get the underlying function to avoid having different states of the function, e.g.
due to pytask_meta, in different layers of the wrapping.

"""
return inspect.unwrap(obj, stop=lambda x: isinstance(x, Function))
15 changes: 13 additions & 2 deletions src/_pytask/task_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

import attrs

from _pytask.coiled_utils import Function
from _pytask.coiled_utils import extract_coiled_function_kwargs
from _pytask.console import get_file
from _pytask.mark import Mark
from _pytask.models import CollectionMetadata
from _pytask.shared import find_duplicates
from _pytask.shared import unwrap_task_function
from _pytask.typing import is_task_function

if TYPE_CHECKING:
Expand Down Expand Up @@ -117,7 +120,12 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
)
raise ValueError(msg)

unwrapped = inspect.unwrap(func)
unwrapped = unwrap_task_function(func)
if isinstance(unwrapped, Function):
coiled_kwargs = extract_coiled_function_kwargs(unwrapped)
unwrapped = unwrap_task_function(unwrapped.function)
else:
coiled_kwargs = None

# We do not allow builtins as functions because we would need to use
# ``inspect.stack`` to infer their caller location and they are unable to carry
Expand Down Expand Up @@ -145,7 +153,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
unwrapped.pytask_meta.produces = produces
unwrapped.pytask_meta.after = parsed_after
else:
unwrapped.pytask_meta = CollectionMetadata(
unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined]
after=parsed_after,
is_generator=is_generator,
id_=id,
Expand All @@ -155,6 +163,9 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
produces=produces,
)

if coiled_kwargs and hasattr(unwrapped, "pytask_meta"):
unwrapped.pytask_meta.attributes["coiled_kwargs"] = coiled_kwargs

# Store it in the global variable ``COLLECTED_TASKS`` to avoid garbage
# collection when the function definition is overwritten in a loop.
COLLECTED_TASKS[path].append(unwrapped)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import functools
import textwrap
from contextlib import ExitStack as does_not_raise # noqa: N813

import coiled
import pytest
from _pytask.shared import convert_to_enum
from _pytask.shared import find_duplicates
from _pytask.shared import unwrap_task_function
from pytask import ExitCode
from pytask import ShowCapture
from pytask import build
Expand Down Expand Up @@ -49,3 +52,29 @@ def test_convert_to_enum(value, enum, expectation, expected):
with expectation:
result = convert_to_enum(value, enum)
assert result == expected


@pytest.mark.unit()
def test_unwrap_task_function():
def task():
pass

# partialed functions are only unwrapped after wraps.
partialed = functools.wraps(task)(functools.partial(task))
assert unwrap_task_function(partialed) is task

partialed = functools.partial(task)
assert unwrap_task_function(partialed) is partialed

def decorator(func):
@functools.wraps(func)
def wrapper():
return func()

return wrapper

decorated = decorator(task)
assert unwrap_task_function(decorated) is task

coiled_decorated = coiled.function()(task)
assert unwrap_task_function(coiled_decorated) is coiled_decorated