Skip to content

Commit 2279aff

Browse files
authored
Add products with typing.Annotation. (#394)
1 parent ae31b65 commit 2279aff

18 files changed

+369
-22
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ repos:
2121
hooks:
2222
- id: python-check-blanket-noqa
2323
- id: python-check-mock-methods
24-
- id: python-no-eval
25-
exclude: expression.py
2624
- id: python-no-log-warn
27-
- id: python-use-type-annotations
2825
- id: text-unicode-replacement-char
2926
- repo: https://github.com/asottile/reorder-python-imports
3027
rev: v3.9.0

docs/rtd_environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies:
3030
- rich
3131
- sqlalchemy >=1.4.36
3232
- tomli >=1.0.0
33+
- typing_extensions
3334

3435
- pip:
3536
- ../

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies:
2020
- rich
2121
- sqlalchemy >=1.4.36
2222
- tomli >=1.0.0
23+
- typing_extensions
2324

2425
# Misc
2526
- black

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ convention = "numpy"
8383

8484

8585
[tool.pytest.ini_options]
86+
addopts = ["--doctest-modules"]
8687
testpaths = ["src", "tests"]
8788
markers = [
8889
"wip: Tests that are work-in-progress.",

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ install_requires =
4040
rich
4141
sqlalchemy>=1.4.36
4242
tomli>=1.0.0
43+
typing-extensions
4344
python_requires = >=3.8
4445
include_package_data = True
4546
package_dir =

src/_pytask/_inspect.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from __future__ import annotations
2+
3+
import functools
4+
import sys
5+
import types
6+
from typing import Any
7+
from typing import Callable
8+
from typing import Mapping
9+
10+
11+
__all__ = ["get_annotations"]
12+
13+
14+
if sys.version_info >= (3, 10):
15+
from inspect import get_annotations
16+
else:
17+
18+
def get_annotations( # noqa: C901, PLR0912, PLR0915
19+
obj: Callable[..., object] | type[Any] | types.ModuleType,
20+
*,
21+
globals: Mapping[str, Any] | None = None, # noqa: A002
22+
locals: Mapping[str, Any] | None = None, # noqa: A002
23+
eval_str: bool = False,
24+
) -> dict[str, Any]:
25+
"""Compute the annotations dict for an object.
26+
27+
obj may be a callable, class, or module.
28+
Passing in an object of any other type raises TypeError.
29+
30+
Returns a dict. get_annotations() returns a new dict every time
31+
it's called; calling it twice on the same object will return two
32+
different but equivalent dicts.
33+
34+
This function handles several details for you:
35+
36+
* If eval_str is true, values of type str will
37+
be un-stringized using eval(). This is intended
38+
for use with stringized annotations
39+
("from __future__ import annotations").
40+
* If obj doesn't have an annotations dict, returns an
41+
empty dict. (Functions and methods always have an
42+
annotations dict; classes, modules, and other types of
43+
callables may not.)
44+
* Ignores inherited annotations on classes. If a class
45+
doesn't have its own annotations dict, returns an empty dict.
46+
* All accesses to object members and dict values are done
47+
using getattr() and dict.get() for safety.
48+
* Always, always, always returns a freshly-created dict.
49+
50+
eval_str controls whether or not values of type str are replaced
51+
with the result of calling eval() on those values:
52+
53+
* If eval_str is true, eval() is called on values of type str.
54+
* If eval_str is false (the default), values of type str are unchanged.
55+
56+
globals and locals are passed in to eval(); see the documentation
57+
for eval() for more information. If either globals or locals is
58+
None, this function may replace that value with a context-specific
59+
default, contingent on type(obj):
60+
61+
* If obj is a module, globals defaults to obj.__dict__.
62+
* If obj is a class, globals defaults to
63+
sys.modules[obj.__module__].__dict__ and locals
64+
defaults to the obj class namespace.
65+
* If obj is a callable, globals defaults to obj.__globals__,
66+
although if obj is a wrapped function (using
67+
functools.update_wrapper()) it is first unwrapped.
68+
"""
69+
if isinstance(obj, type):
70+
# class
71+
obj_dict = getattr(obj, "__dict__", None)
72+
if obj_dict and hasattr(obj_dict, "get"):
73+
ann = obj_dict.get("__annotations__", None)
74+
if isinstance(ann, types.GetSetDescriptorType):
75+
ann = None
76+
else:
77+
ann = None
78+
79+
obj_globals = None
80+
module_name = getattr(obj, "__module__", None)
81+
if module_name:
82+
module = sys.modules.get(module_name, None)
83+
if module:
84+
obj_globals = getattr(module, "__dict__", None)
85+
obj_locals = dict(vars(obj))
86+
unwrap = obj
87+
elif isinstance(obj, types.ModuleType):
88+
# module
89+
ann = getattr(obj, "__annotations__", None)
90+
obj_globals = obj.__dict__
91+
obj_locals = None
92+
unwrap = None
93+
elif callable(obj):
94+
# this includes types.Function, types.BuiltinFunctionType,
95+
# types.BuiltinMethodType, functools.partial, functools.singledispatch,
96+
# "class funclike" from Lib/test/test_inspect... on and on it goes.
97+
ann = getattr(obj, "__annotations__", None)
98+
obj_globals = getattr(obj, "__globals__", None)
99+
obj_locals = None
100+
unwrap = obj
101+
else:
102+
raise TypeError(f"{obj!r} is not a module, class, or callable.")
103+
104+
if ann is None:
105+
return {}
106+
107+
if not isinstance(ann, dict):
108+
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
109+
110+
if not ann:
111+
return {}
112+
113+
if not eval_str:
114+
return dict(ann)
115+
116+
if unwrap is not None:
117+
while True:
118+
if hasattr(unwrap, "__wrapped__"):
119+
unwrap = unwrap.__wrapped__
120+
continue
121+
if isinstance(unwrap, functools.partial):
122+
unwrap = unwrap.func
123+
continue
124+
break
125+
if hasattr(unwrap, "__globals__"):
126+
obj_globals = unwrap.__globals__
127+
128+
if globals is None:
129+
globals = obj_globals # noqa: A001
130+
if locals is None:
131+
locals = obj_locals # noqa: A001
132+
133+
eval_func = eval
134+
return_value = {
135+
key: value
136+
if not isinstance(value, str)
137+
else eval_func(value, globals, locals)
138+
for key, value in ann.items()
139+
}
140+
return return_value

src/_pytask/collect.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from _pytask.collect_utils import parse_dependencies_from_task_function
1616
from _pytask.collect_utils import parse_nodes
1717
from _pytask.collect_utils import parse_products_from_task_function
18-
from _pytask.collect_utils import produces
1918
from _pytask.config import hookimpl
2019
from _pytask.config import IS_FILE_SYSTEM_CASE_SENSITIVE
2120
from _pytask.console import console
@@ -178,10 +177,7 @@ def pytask_collect_task(
178177
session, path, name, obj
179178
)
180179

181-
if has_mark(obj, "produces"):
182-
products = parse_nodes(session, path, name, obj, produces)
183-
else:
184-
products = parse_products_from_task_function(session, path, name, obj)
180+
products = parse_products_from_task_function(session, path, name, obj)
185181

186182
markers = obj.pytask_meta.markers if hasattr(obj, "pytask_meta") else []
187183

src/_pytask/collect_utils.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,19 @@
1111
from typing import Iterable
1212
from typing import TYPE_CHECKING
1313

14+
from _pytask._inspect import get_annotations
1415
from _pytask.exceptions import NodeNotCollectedError
16+
from _pytask.mark_utils import has_mark
1517
from _pytask.mark_utils import remove_marks
18+
from _pytask.nodes import ProductType
1619
from _pytask.nodes import PythonNode
1720
from _pytask.shared import find_duplicates
1821
from _pytask.task_utils import parse_keyword_arguments_from_signature_defaults
1922
from attrs import define
2023
from attrs import field
2124
from pybaum.tree_util import tree_map
25+
from typing_extensions import Annotated
26+
from typing_extensions import get_origin
2227

2328

2429
if TYPE_CHECKING:
@@ -211,8 +216,12 @@ def parse_dependencies_from_task_function(
211216
kwargs = {**signature_defaults, **task_kwargs}
212217
kwargs.pop("produces", None)
213218

219+
parameters_with_product_annot = _find_args_with_product_annotation(obj)
220+
214221
dependencies = {}
215222
for name, value in kwargs.items():
223+
if name in parameters_with_product_annot:
224+
continue
216225
parsed_value = tree_map(
217226
lambda x: _collect_dependencies(session, path, name, x), value # noqa: B023
218227
)
@@ -223,29 +232,108 @@ def parse_dependencies_from_task_function(
223232
return dependencies
224233

225234

235+
_ERROR_MULTIPLE_PRODUCT_DEFINITIONS = (
236+
"The task uses multiple ways to define products. Products should be defined with "
237+
"either\n\n- 'typing.Annotated[Path(...), Product]' (recommended)\n"
238+
"- '@pytask.mark.task(kwargs={'produces': Path(...)})'\n"
239+
"- as a default argument for 'produces': 'produces = Path(...)'\n"
240+
"- '@pytask.mark.produces(Path(...))' (deprecated).\n\n"
241+
"Read more about products in the documentation: https://tinyurl.com/yrezszr4."
242+
)
243+
244+
226245
def parse_products_from_task_function(
227246
session: Session, path: Path, name: str, obj: Any
228247
) -> dict[str, Any]:
229-
"""Parse dependencies from task function."""
248+
"""Parse products from task function.
249+
250+
Raises
251+
------
252+
NodeNotCollectedError
253+
If multiple ways were used to specify products.
254+
255+
"""
256+
has_produces_decorator = False
257+
has_task_decorator = False
258+
has_signature_default = False
259+
has_annotation = False
260+
out = {}
261+
262+
if has_mark(obj, "produces"):
263+
has_produces_decorator = True
264+
nodes = parse_nodes(session, path, name, obj, produces)
265+
out = {"produces": nodes}
266+
230267
task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
231268
if "produces" in task_kwargs:
232-
return tree_map(
269+
collected_products = tree_map(
233270
lambda x: _collect_product(session, path, name, x, is_string_allowed=True),
234271
task_kwargs["produces"],
235272
)
273+
out = {"produces": collected_products}
236274

237275
parameters = inspect.signature(obj).parameters
238-
if "produces" in parameters:
276+
277+
if not has_mark(obj, "task") and "produces" in parameters:
239278
parameter = parameters["produces"]
240279
if parameter.default is not parameter.empty:
280+
has_signature_default = True
241281
# Use _collect_new_node to not collect strings.
242-
return tree_map(
282+
collected_products = tree_map(
243283
lambda x: _collect_product(
244284
session, path, name, x, is_string_allowed=False
245285
),
246286
parameter.default,
247287
)
248-
return {}
288+
out = {"produces": collected_products}
289+
290+
parameters_with_product_annot = _find_args_with_product_annotation(obj)
291+
if parameters_with_product_annot:
292+
has_annotation = True
293+
for parameter_name in parameters_with_product_annot:
294+
parameter = parameters[parameter_name]
295+
if parameter.default is not parameter.empty:
296+
# Use _collect_new_node to not collect strings.
297+
collected_products = tree_map(
298+
lambda x: _collect_product(
299+
session, path, name, x, is_string_allowed=False
300+
),
301+
parameter.default,
302+
)
303+
out = {parameter_name: collected_products}
304+
305+
if (
306+
sum(
307+
(
308+
has_produces_decorator,
309+
has_task_decorator,
310+
has_signature_default,
311+
has_annotation,
312+
)
313+
)
314+
>= 2 # noqa: PLR2004
315+
):
316+
raise NodeNotCollectedError(_ERROR_MULTIPLE_PRODUCT_DEFINITIONS)
317+
318+
return out
319+
320+
321+
def _find_args_with_product_annotation(func: Callable[..., Any]) -> list[str]:
322+
"""Find args with product annotation."""
323+
annotations = get_annotations(func, eval_str=True)
324+
metas = {
325+
name: annotation.__metadata__
326+
for name, annotation in annotations.items()
327+
if get_origin(annotation) is Annotated
328+
}
329+
330+
args_with_product_annot = []
331+
for name, meta in metas.items():
332+
has_product_annot = any(isinstance(i, ProductType) for i in meta)
333+
if has_product_annot:
334+
args_with_product_annot.append(name)
335+
336+
return args_with_product_annot
249337

250338

251339
def _collect_old_dependencies(
@@ -331,9 +419,10 @@ def _collect_product(
331419
# The parameter defaults only support Path objects.
332420
if not isinstance(node, Path) and not is_string_allowed:
333421
raise ValueError(
334-
"If you use 'produces' as an argument of a task, it can only accept values "
335-
"of type 'pathlib.Path' or the same value nested in "
336-
f"tuples, lists, and dictionaries. Here, {node} has type {type(node)}."
422+
"If you use 'produces' as a function argument of a task and pass values as "
423+
"function defaults, it can only accept values of type 'pathlib.Path' or "
424+
"the same value nested in tuples, lists, and dictionaries. Here, "
425+
f"{node!r} has type {type(node)}."
337426
)
338427

339428
if isinstance(node, str):

src/_pytask/execute.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ def pytask_execute_task(session: Session, task: Task) -> bool:
149149
for name, value in task.depends_on.items():
150150
kwargs[name] = tree_map(lambda x: x.value, value)
151151

152-
if task.produces and "produces" in parameters:
153-
kwargs["produces"] = tree_map(lambda x: x.value, task.produces)
152+
for name, value in task.produces.items():
153+
if name in parameters:
154+
kwargs[name] = tree_map(lambda x: x.value, value)
154155

155156
task.execute(**kwargs)
156157
return True

src/_pytask/nodes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
from _pytask.mark import Mark
1818

1919

20-
__all__ = ["FilePathNode", "MetaNode", "Task"]
20+
__all__ = ["FilePathNode", "MetaNode", "Product", "Task"]
21+
22+
23+
@define(frozen=True)
24+
class ProductType:
25+
"""A class to mark products."""
26+
27+
28+
Product = ProductType()
2129

2230

2331
class MetaNode(metaclass=ABCMeta):

0 commit comments

Comments
 (0)