Skip to content

Commit 54d7b87

Browse files
authored
Handle local paths in remote workers. (#96)
1 parent ef2fbca commit 54d7b87

File tree

7 files changed

+482
-58
lines changed

7 files changed

+482
-58
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ build
1717
dist
1818
src/pytask_parallel/_version.py
1919
tests/test_jupyter/*.txt
20+
.mypy_cache
21+
.pytest_cache
22+
.ruff_cache

docs/source/changes.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
1717
- {pull}`94` implements `ParallelBackend.NONE` as the default backend.
1818
- {pull}`95` formalizes parallel backends and applies wrappers for backends with threads
1919
or processes automatically.
20+
- {pull}`96` handles local paths with remote executors. `PathNode`s are not supported as
21+
dependencies or products (except for return annotations).
2022

2123
## 0.4.1 - 2024-01-12
2224

src/pytask_parallel/execute.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,16 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
8080
session.hook.pytask_execute_task_setup(
8181
session=session, task=task
8282
)
83+
running_tasks[task_name] = session.hook.pytask_execute_task(
84+
session=session, task=task
85+
)
86+
sleeper.reset()
8387
except Exception: # noqa: BLE001
8488
report = ExecutionReport.from_task_and_exception(
8589
task, sys.exc_info()
8690
)
8791
newly_collected_reports.append(report)
8892
session.scheduler.done(task_name)
89-
else:
90-
running_tasks[task_name] = session.hook.pytask_execute_task(
91-
session=session, task=task
92-
)
93-
sleeper.reset()
9493

9594
if not ready_tasks:
9695
sleeper.increment()
@@ -123,7 +122,9 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
123122
session.scheduler.done(task_name)
124123
else:
125124
task = session.dag.nodes[task_name]["task"]
126-
_update_python_nodes(task, wrapper_result.python_nodes)
125+
_update_carry_over_products(
126+
task, wrapper_result.carry_over_products
127+
)
127128

128129
try:
129130
session.hook.pytask_execute_task_teardown(
@@ -169,9 +170,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
169170
executor.
170171
171172
"""
172-
worker_type = registry.registry[session.config["parallel_backend"]].worker_type
173+
parallel_backend = registry.registry[session.config["parallel_backend"]]
174+
worker_type = parallel_backend.worker_type
175+
remote = parallel_backend.remote
173176

174-
kwargs = create_kwargs_for_task(task)
177+
kwargs = create_kwargs_for_task(task, remote=remote)
175178

176179
if worker_type == WorkerType.PROCESSES:
177180
# Prevent circular import for loky backend.
@@ -188,10 +191,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
188191
return session.config["_parallel_executor"].submit(
189192
wrap_task_in_process,
190193
task=task,
191-
kwargs=kwargs,
192-
show_locals=session.config["show_locals"],
193194
console_options=console.options,
195+
kwargs=kwargs,
196+
remote=remote,
194197
session_filterwarnings=session.config["filterwarnings"],
198+
show_locals=session.config["show_locals"],
195199
task_filterwarnings=get_marks(task, "filterwarnings"),
196200
)
197201
if worker_type == WorkerType.THREADS:
@@ -211,21 +215,23 @@ def pytask_unconfigure() -> None:
211215
registry.reset()
212216

213217

214-
def _update_python_nodes(
215-
task: PTask, python_nodes: PyTree[PythonNode | None] | None
218+
def _update_carry_over_products(
    task: PTask, carry_over_products: PyTree[PythonNode | None] | None
) -> None:
    """Update a task's products with values carried over from a worker.

    The carried-over products come from another process or a remote worker.
    Each carried-over node that is not ``None`` is loaded and its value is
    saved into the matching product of ``task``. Nothing happens unless the
    tree structure of ``task.produces`` is a prefix of the structure of
    ``carry_over_products``.

    Parameters
    ----------
    task
        The task whose ``produces`` tree is updated in place.
    carry_over_products
        A tree of nodes (or ``None`` leaves) returned with the worker result.

    """

    def _update_carry_over_node(x: PNode, y: PythonNode | None) -> PNode:
        # Compare against ``None`` explicitly so that only a missing node is
        # skipped; a ``None`` leaf leaves the product untouched.
        if y is not None:
            x.save(y.load())
        return x

    structure_carry_over_products = tree_structure(carry_over_products)
    structure_produces = tree_structure(task.produces)
    # strict must be false when none is leaf.
    if structure_produces.is_prefix(structure_carry_over_products, strict=False):
        task.produces = tree_map(
            _update_carry_over_node, task.produces, carry_over_products
        )  # type: ignore[assignment]
229235

230236

231237
@define(kw_only=True)

src/pytask_parallel/utils.py

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44

55
import inspect
66
from functools import partial
7+
from pathlib import PosixPath
8+
from pathlib import WindowsPath
79
from typing import TYPE_CHECKING
810
from typing import Any
911
from typing import Callable
1012

13+
from pytask import NodeLoadError
14+
from pytask import PathNode
15+
from pytask import PNode
16+
from pytask import PProvisionalNode
1117
from pytask.tree_util import PyTree
12-
from pytask.tree_util import tree_map
18+
from pytask.tree_util import tree_map_with_path
19+
from upath.implementations.local import FilePath
1320

1421
if TYPE_CHECKING:
1522
from concurrent.futures import Future
@@ -22,7 +29,12 @@
2229
from pytask_parallel.wrappers import WrapperResult
2330

2431

25-
__all__ = ["create_kwargs_for_task", "get_module", "parse_future_result"]
32+
__all__ = [
33+
"create_kwargs_for_task",
34+
"get_module",
35+
"parse_future_result",
36+
"is_local_path",
37+
]
2638

2739

2840
def parse_future_result(
@@ -32,11 +44,12 @@ def parse_future_result(
3244
# An exception was raised before the task was executed.
3345
future_exception = future.exception()
3446
if future_exception is not None:
47+
# Prevent circular import for loky backend.
3548
from pytask_parallel.wrappers import WrapperResult
3649

3750
exc_info = _parse_future_exception(future_exception)
3851
return WrapperResult(
39-
python_nodes=None,
52+
carry_over_products=None,
4053
warning_reports=[],
4154
exc_info=exc_info,
4255
stdout="",
@@ -45,17 +58,80 @@ def parse_future_result(
4558
return future.result()
4659

4760

48-
def create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
61+
def _safe_load(
    path: tuple[Any, ...],
    node: PNode | PProvisionalNode,
    task: PTask,
    *,
    is_product: bool,
    remote: bool,
) -> Any:
    """Load the value of a node and report failures as ``NodeLoadError``.

    Parameters
    ----------
    path
        Location of the node inside the task's arguments. The first entry is
        the argument name, or ``"return"`` for function returns.
    node
        The node whose value is loaded.
    task
        The task the node belongs to; only its name is used in error messages.
    is_product
        Whether the node is a product of the task; forwarded to ``node.load``.
    remote
        Whether the task is executed by a remote backend.

    Raises
    ------
    NodeLoadError
        If a ``PathNode`` with a local path is used as a dependency or product
        (other than the return value) with a remote backend, or if loading the
        node raises any exception.

    """
    # NOTE(review): presumably inspected by rich when rendering tracebacks;
    # confirm against rich's traceback documentation before removing.
    _rich_traceback_guard = True
    # The leading path entry is the argument name like "path" or "return".
    argument = path[0]

    # Local paths cannot be reached from a remote worker, so they are rejected
    # unless the value travels back through the return annotation.
    local_path_on_remote = (
        remote
        and argument != "return"
        and isinstance(node, PathNode)
        and is_local_path(node.path)
    )
    if local_path_on_remote:
        if is_product:
            msg = (
                f"You cannot use a local path as a product in argument {argument!r} "
                "with a remote backend. Either return the content that should be saved "
                "in the file with a return annotation "
                "(https://tinyurl.com/pytask-return) or use a nonlocal path to store "
                "the file in S3 or their like https://tinyurl.com/pytask-remote."
            )
        else:
            msg = (
                f"You cannot use a local path as a dependency in argument {argument!r} "
                "with a remote backend. Upload the file to a remote storage like S3 "
                "and use the remote path instead: https://tinyurl.com/pytask-remote."
            )
        raise NodeLoadError(msg)

    try:
        value = node.load(is_product=is_product)
    except Exception as e:  # noqa: BLE001
        msg = f"Exception while loading node {node.name!r} of task {task.name!r}"
        raise NodeLoadError(msg) from e
    return value
103+
104+
105+
def create_kwargs_for_task(task: PTask, *, remote: bool) -> dict[str, PyTree[Any]]:
    """Build the keyword arguments passed to the task function.

    Every dependency is loaded. A product is loaded only when its name is a
    parameter of the task function. All values are loaded through
    ``_safe_load``.

    Parameters
    ----------
    task
        The task whose dependencies and products are loaded.
    remote
        Whether the task is executed by a remote backend; forwarded to
        ``_safe_load``.

    """

    def _load_tree(argument: str, tree: PyTree[Any], *, is_product: bool) -> Any:
        # Prefix each leaf path with the argument name so the loader can tell
        # which argument a node belongs to.
        return tree_map_with_path(
            lambda p, x: _safe_load(
                (argument, *p),
                x,
                task,
                is_product=is_product,
                remote=remote,
            ),
            tree,
        )

    accepted = inspect.signature(task.function).parameters

    kwargs = {
        name: _load_tree(name, tree, is_product=False)
        for name, tree in task.depends_on.items()
    }
    kwargs.update(
        (name, _load_tree(name, tree, is_product=True))
        for name, tree in task.produces.items()
        if name in accepted
    )
    return kwargs
61137

@@ -84,3 +160,8 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
84160
if path:
85161
return inspect.getmodule(func, path.as_posix())
86162
return inspect.getmodule(func)
163+
164+
165+
def is_local_path(path: Path) -> bool:
    """Return whether ``path`` is an instance of one of the local path classes."""
    local_path_types = (FilePath, PosixPath, WindowsPath)
    return isinstance(path, local_path_types)

0 commit comments

Comments
 (0)