Skip to content

Commit 58d940e

Browse files
authored
Add is_product to PNode.load(). (#472)
1 parent 2d543bc commit 58d940e

15 files changed

+179
-25
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ repos:
5151
rev: v0.1.1
5252
hooks:
5353
- id: ruff
54+
args: [--unsafe-fixes]
5455
- repo: https://github.com/dosisod/refurb
5556
rev: v1.22.1
5657
hooks:

docs/source/changes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
1919
- {pull}`463` raise error when a task function is not defined inside the loop body.
2020
- {pull}`464` improves pinned dependencies.
2121
- {pull}`465` adds test to ensure internal tracebacks are removed by reports.
22+
- {pull}`472` adds `is_product` to {meth}`PNode.load`.
2223

2324
## 0.4.1 - 2023-10-11
2425

docs/source/how_to_guides/writing_custom_nodes.md

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ to inputs and outputs and call {func}`pandas.read_pickle` and
2020
To remove IO operations from the task and delegate them to pytask, we will write a
2121
`PickleNode` that automatically loads and stores Python objects.
2222

23-
We will also use the feature explained in {doc}`using_task_returns` to define products
24-
of the task function via the function's return value.
25-
2623
And we pass the value to `df` via {obj}`Annotated` to preserve the type hint.
2724

2825
The result will be the following task.
@@ -37,12 +34,28 @@ The result will be the following task.
3734

3835
:::
3936

37+
:::{tab-item} Python 3.10+ & Return
38+
:sync: python310plus
39+
40+
```{literalinclude} ../../../docs_src/how_to_guides/writing_custom_nodes_example_2_py310_return.py
41+
```
42+
43+
:::
44+
4045
:::{tab-item} Python 3.8+
4146
:sync: python38plus
4247

4348
```{literalinclude} ../../../docs_src/how_to_guides/writing_custom_nodes_example_2_py38.py
4449
```
4550

51+
:::
52+
53+
:::{tab-item} Python 3.8+ & Return
54+
:sync: python38plus
55+
56+
```{literalinclude} ../../../docs_src/how_to_guides/writing_custom_nodes_example_2_py38_return.py
57+
```
58+
4659
:::
4760
::::
4861

@@ -97,7 +110,12 @@ Here are some explanations.
97110
the value changes, pytask knows it needs to regenerate the workflow. We can use
98111
the timestamp of when the node was last modified.
99112
- pytask calls {meth}`PickleNode.load` when it collects the values of function arguments
100-
to run the function. In our example, we read the file and unpickle the data.
113+
to run the function. The argument `is_product` signals that the node is loaded as a
114+
product with a {class}`~pytask.Product` annotation or via `produces`.
115+
116+
When the node is loaded as a dependency, we want to inject the value of the pickle
117+
file. In the other case, the node returns itself so users can call
118+
{meth}`PickleNode.save` themselves.
101119
- {meth}`PickleNode.save` is called when a task function returns and allows to save the
102120
return values.
103121

docs_src/how_to_guides/writing_custom_nodes_example_2_py310.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Annotated
33

44
import pandas as pd
5+
from pytask import Product
56

67

78
class PickleNode:
@@ -13,6 +14,7 @@ class PickleNode:
1314

1415

1516
def task_example(
16-
df: Annotated[pd.DataFrame, in_node]
17-
) -> Annotated[pd.DataFrame, out_node]:
18-
return df.apply(...)
17+
df: Annotated[pd.DataFrame, in_node], out: Annotated[PickleNode, out_node, Product]
18+
) -> None:
19+
transformed = df.apply(...)
20+
out.save(transformed)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pathlib import Path
2+
from typing import Annotated
3+
4+
import pandas as pd
5+
6+
7+
class PickleNode:
8+
...
9+
10+
11+
in_node = PickleNode.from_path(Path(__file__).parent / "in.pkl")
12+
out_node = PickleNode.from_path(Path(__file__).parent / "out.pkl")
13+
14+
15+
def task_example(
16+
df: Annotated[pd.DataFrame, in_node]
17+
) -> Annotated[pd.DataFrame, out_node]:
18+
return df.apply(...)

docs_src/how_to_guides/writing_custom_nodes_example_2_py38.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22

33
import pandas as pd
4+
from pytask import Product
45
from typing_extensions import Annotated
56

67

@@ -13,6 +14,7 @@ class PickleNode:
1314

1415

1516
def task_example(
16-
df: Annotated[pd.DataFrame, in_node]
17-
) -> Annotated[pd.DataFrame, out_node]:
18-
return df.apply(...)
17+
df: Annotated[pd.DataFrame, in_node], out: Annotated[PickleNode, out_node, Product]
18+
) -> None:
19+
transformed = df.apply(...)
20+
out.save(transformed)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from pathlib import Path
2+
3+
import pandas as pd
4+
from typing_extensions import Annotated
5+
6+
7+
class PickleNode:
8+
...
9+
10+
11+
in_node = PickleNode.from_path(Path(__file__).parent / "in.pkl")
12+
out_node = PickleNode.from_path(Path(__file__).parent / "out.pkl")
13+
14+
15+
def task_example(
16+
df: Annotated[pd.DataFrame, in_node]
17+
) -> Annotated[pd.DataFrame, out_node]:
18+
return df.apply(...)

docs_src/how_to_guides/writing_custom_nodes_example_3_py310.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ def state(self) -> str | None:
3333
return str(self.path.stat().st_mtime)
3434
return None
3535

36-
def load(self) -> Path:
36+
def load(self, is_product: bool) -> Path:
3737
"""Load the value from the path."""
38+
if is_product:
39+
return self
3840
return pickle.loads(self.path.read_bytes())
3941

4042
def save(self, value: Any) -> None:

docs_src/how_to_guides/writing_custom_nodes_example_3_py38.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ def state(self) -> Optional[str]:
3434
return str(self.path.stat().st_mtime)
3535
return None
3636

37-
def load(self) -> Path:
37+
def load(self, is_product: bool) -> Path:
3838
"""Load the value from the path."""
39+
if is_product:
40+
return self
3941
return pickle.loads(self.path.read_bytes())
4042

4143
def save(self, value: Any) -> None:

src/_pytask/execute.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,9 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
144144
raise WouldBeExecuted
145145

146146

147-
def _safe_load(node: PNode, task: PTask) -> Any:
147+
def _safe_load(node: PNode, task: PTask, is_product: bool) -> Any:
148148
try:
149-
return node.load()
149+
return node.load(is_product=is_product)
150150
except Exception as e: # noqa: BLE001
151151
task_name = getattr(task, "display_name", task.name)
152152
msg = f"Exception while loading node {node.name!r} of task {task_name!r}"
@@ -163,11 +163,11 @@ def pytask_execute_task(session: Session, task: PTask) -> bool:
163163

164164
kwargs = {}
165165
for name, value in task.depends_on.items():
166-
kwargs[name] = tree_map(lambda x: _safe_load(x, task), value)
166+
kwargs[name] = tree_map(lambda x: _safe_load(x, task, False), value)
167167

168168
for name, value in task.produces.items():
169169
if name in parameters:
170-
kwargs[name] = tree_map(lambda x: _safe_load(x, task), value)
170+
kwargs[name] = tree_map(lambda x: _safe_load(x, task, True), value)
171171

172172
out = task.execute(**kwargs)
173173

src/_pytask/node_protocols.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,15 @@ def state(self) -> str | None:
3939
class PNode(MetaNode, Protocol):
4040
"""Protocol for nodes."""
4141

42-
def load(self) -> Any:
43-
"""Return the value of the node that will be injected into the task."""
42+
def load(self, is_product: bool) -> Any:
43+
"""Return the value of the node that will be injected into the task.
44+
45+
Parameters
46+
----------
47+
is_product
48+
Indicates whether the node is loaded as a dependency or as a product.
49+
50+
"""
4451
...
4552

4653
def save(self, value: Any) -> Any:

src/_pytask/nodes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def state(self) -> str | None:
173173
return str(self.path.stat().st_mtime)
174174
return None
175175

176-
def load(self) -> Path:
176+
def load(self, is_product: bool = False) -> Path: # noqa: ARG002
177177
"""Load the value."""
178178
return self.path
179179

@@ -207,8 +207,10 @@ class PythonNode(PNode):
207207
value: Any | NoDefault = no_default
208208
hash: bool | Callable[[Any], bool] = False # noqa: A003
209209

210-
def load(self) -> Any:
210+
def load(self, is_product: bool = False) -> Any:
211211
"""Load the value."""
212+
if is_product:
213+
return self
212214
if isinstance(self.value, PythonNode):
213215
return self.value.load()
214216
return self.value

tests/test_execute.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def task_example() -> Annotated[str, (node1, node2)]:
606606

607607

608608
@pytest.mark.end_to_end()
609-
def test_return_with_custom_type_annotation_as_return(runner, tmp_path):
609+
def test_return_with_custom_node_and_return_annotation(runner, tmp_path):
610610
source = """
611611
from __future__ import annotations
612612
@@ -626,7 +626,9 @@ def state(self) -> str | None:
626626
return str(self.path.stat().st_mtime)
627627
return None
628628
629-
def load(self) -> Any:
629+
def load(self, is_product) -> Any:
630+
if is_product:
631+
return self
630632
return pickle.loads(self.path.read_bytes())
631633
632634
def save(self, value: Any) -> None:
@@ -645,6 +647,49 @@ def task_example() -> Annotated[int, node]:
645647
assert data == 1
646648

647649

650+
@pytest.mark.end_to_end()
651+
def test_return_with_custom_node_with_product_annotation(runner, tmp_path):
652+
source = """
653+
from __future__ import annotations
654+
655+
from pathlib import Path
656+
import pickle
657+
from typing import Any
658+
from typing_extensions import Annotated
659+
import attrs
660+
from pytask import Product
661+
662+
@attrs.define
663+
class PickleNode:
664+
name: str
665+
path: Path
666+
667+
def state(self) -> str | None:
668+
if self.path.exists():
669+
return str(self.path.stat().st_mtime)
670+
return None
671+
672+
def load(self, is_product) -> Any:
673+
if is_product:
674+
return self
675+
return pickle.loads(self.path.read_bytes())
676+
677+
def save(self, value: Any) -> None:
678+
self.path.write_bytes(pickle.dumps(value))
679+
680+
node = PickleNode("pickled_data", Path(__file__).parent.joinpath("data.pkl"))
681+
682+
def task_example(node: Annotated[PickleNode, node, Product]) -> None:
683+
node.save(1)
684+
"""
685+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
686+
result = runner.invoke(cli, [tmp_path.as_posix()])
687+
assert result.exit_code == ExitCode.OK
688+
689+
data = pickle.loads(tmp_path.joinpath("data.pkl").read_bytes()) # noqa: S301
690+
assert data == 1
691+
692+
648693
@pytest.mark.end_to_end()
649694
def test_error_when_return_pytree_mismatch(runner, tmp_path):
650695
source = """
@@ -835,3 +880,23 @@ def task_example(
835880

836881
# Test that traceback is hidden.
837882
assert "_pytask/execute.py" not in result.output
883+
884+
885+
def test_python_node_as_product_with_product_annotation(runner, tmp_path):
886+
source = """
887+
from typing_extensions import Annotated
888+
from pytask import Product, PythonNode
889+
from pathlib import Path
890+
891+
node = PythonNode()
892+
893+
def task_create_string(node: Annotated[PythonNode, node, Product]) -> None:
894+
node.save("Hello, World!")
895+
896+
def task_write_file(text: Annotated[str, node]) -> Annotated[str, Path("file.txt")]:
897+
return text
898+
"""
899+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
900+
result = runner.invoke(cli, [tmp_path.as_posix()])
901+
assert result.exit_code == ExitCode.OK
902+
assert tmp_path.joinpath("file.txt").read_text() == "Hello, World!"

tests/test_node_protocols.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CustomNode:
2424
def state(self):
2525
return self.value
2626
27-
def load(self):
27+
def load(self, is_product):
2828
return self.value
2929
3030
def save(self, value):
@@ -61,7 +61,7 @@ class PickleFile:
6161
def state(self):
6262
return str(self.path.stat().st_mtime)
6363
64-
def load(self):
64+
def load(self, is_product):
6565
with self.path.open("rb") as f:
6666
out = pickle.load(f)
6767
return out

tests/test_nodes.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
import pytest
4-
from _pytask.nodes import PythonNode
4+
from pytask import PathNode
5+
from pytask import PNode
6+
from pytask import PPathNode
7+
from pytask import PythonNode
58

69

710
@pytest.mark.unit()
@@ -19,3 +22,16 @@ def test_hash_of_python_node(value, hash_, expected):
1922
node = PythonNode(name="test", value=value, hash=hash_)
2023
state = node.state()
2124
assert state == expected
25+
26+
27+
@pytest.mark.parametrize(
28+
("node", "protocol", "expected"),
29+
[
30+
(PathNode, PNode, True),
31+
(PathNode, PPathNode, True),
32+
(PythonNode, PNode, True),
33+
(PythonNode, PPathNode, False),
34+
],
35+
)
36+
def test_comply_with_protocol(node, protocol, expected):
37+
assert isinstance(node, protocol) is expected

0 commit comments

Comments
 (0)