diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 848759d..ff24424 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -18,20 +18,20 @@ on:
 jobs:

-  build-package:
-    name: Build & verify package
+  run-type-checking:
+
+    name: Run tests for type-checking
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
-          fetch-depth: 0
-
-      - uses: hynek/build-and-inspect-python-package@v2
-        id: baipp
-
-    outputs:
-      python-versions: ${{ steps.baipp.outputs.supported_python_classifiers_json_array }}
+          python-version-file: .python-version
+          allow-prereleases: true
+          cache: pip
+      - run: pip install tox-uv
+      - run: tox -e typing

   run-tests:
@@ -59,16 +59,16 @@ jobs:
         shell: bash -l {0}
         run: tox -e test -- tests -m "unit or (not integration and not end_to_end)" --cov=src --cov=tests --cov-report=xml

-      - name: Upload coverage report for unit tests and doctests.
-        if: runner.os == 'Linux' && matrix.python-version == '3.10'
-        shell: bash -l {0}
-        run: bash <(curl -s https://codecov.io/bash) -F unit -c
+      - name: Upload unit test coverage reports to Codecov with GitHub Action
+        uses: codecov/codecov-action@v4
+        with:
+          flags: unit

       - name: Run end-to-end tests.
         shell: bash -l {0}
         run: tox -e test -- tests -m end_to_end --cov=src --cov=tests --cov-report=xml

-      - name: Upload coverage reports of end-to-end tests.
-        if: runner.os == 'Linux' && matrix.python-version == '3.10'
-        shell: bash -l {0}
-        run: bash <(curl -s https://codecov.io/bash) -F end_to_end -c
+      - name: Upload end_to_end test coverage reports to Codecov with GitHub Action
+        uses: codecov/codecov-action@v4
+        with:
+          flags: end_to_end
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e774b6d..a59c564 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
   - id: no-commit-to-branch
     args: [--branch, main]
 - repo: https://github.com/pre-commit/pygrep-hooks
-  rev: v1.10.0 # Use the ref you want to point at
+  rev: v1.10.0
   hooks:
   - id: python-check-blanket-noqa
   - id: python-check-mock-methods
@@ -44,39 +44,22 @@ repos:
       mdformat-black,
       mdformat-pyproject,
     ]
+    args: [--wrap, "88"]
     files: (docs/.)
-# Conflicts with admonitions.
-# - repo: https://github.com/executablebooks/mdformat
-#   rev: 0.7.17
-#   hooks:
-#   - id: mdformat
-#     additional_dependencies: [
-#       mdformat-gfm,
-#       mdformat-black,
-#     ]
-#     args: [--wrap, "88"]
+- repo: https://github.com/executablebooks/mdformat
+  rev: 0.7.17
+  hooks:
+  - id: mdformat
+    additional_dependencies: [
+      mdformat-gfm,
+      mdformat-black,
+    ]
+    args: [--wrap, "88"]
+    files: (README\.md)
 - repo: https://github.com/codespell-project/codespell
   rev: v2.2.6
   hooks:
   - id: codespell
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: 'v1.10.0'
-  hooks:
-  - id: mypy
-    args: [
-      --no-strict-optional,
-      --ignore-missing-imports,
-    ]
-    additional_dependencies: [
-      attrs,
-      cloudpickle,
-      loky,
-      "git+https://github.com/pytask-dev/pytask.git@main",
-      rich,
-      types-click,
-      types-setuptools,
-    ]
-    pass_filenames: false
 - repo: meta
   hooks:
   - id: check-hooks-apply
diff --git a/docs/source/changes.md b/docs/source/changes.md
index c7215bd..bcaa1a6 100644
--- a/docs/source/changes.md
+++ b/docs/source/changes.md
@@ -23,6 +23,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
   interactions with adaptive scaling. {issue}`98` does handle the resulting issues: no
   strong adherence to priorities, no pending status.
 - {pull}`100` adds project management with rye.
+- {pull}`101` adds syncing for local paths as dependencies or products in remote
+  environments with the same OS.

 ## 0.4.1 - 2024-01-12
diff --git a/docs/source/coiled.md b/docs/source/coiled.md
index d40b475..5463f7b 100644
--- a/docs/source/coiled.md
+++ b/docs/source/coiled.md
@@ -66,6 +66,14 @@ configure the hardware and software environment.
 ```{literalinclude} ../../docs_src/coiled/coiled_functions_task.py
 ```

+By default, {func}`@coiled.function <coiled.function>`
+[scales adaptively](https://docs.coiled.io/user_guide/usage/functions/index.html#adaptive-scaling)
+to the workload. This means that coiled infers from the number of submitted tasks and
+previous runtimes how many additional remote workers it should deploy to handle the
+workload. It provides a convenient mechanism for scaling without intervention. Workers
+launched by {func}`@coiled.function <coiled.function>` also shut down more quickly than
+a cluster.
+
 ```{seealso}
 Serverless functions are more thoroughly explained in
 [coiled's guide](https://docs.coiled.io/user_guide/usage/functions/index.html).
diff --git a/docs/source/dask.md b/docs/source/dask.md
index 99eabea..0f023ea 100644
--- a/docs/source/dask.md
+++ b/docs/source/dask.md
@@ -83,9 +83,9 @@ You can find more information in the documentation for

 ## Remote

-You can learn how to deploy your tasks to a remote dask cluster in [this
-guide](https://docs.dask.org/en/stable/deploying.html). They recommend to use coiled for
-deployment to cloud providers.
+You can learn how to deploy your tasks to a remote dask cluster in
+[this guide](https://docs.dask.org/en/stable/deploying.html). They recommend using
+coiled for deployment to cloud providers.

 [coiled](https://www.coiled.io/) is a product built on top of dask that eases the
 deployment of your workflow to many cloud providers like AWS, GCP, and Azure.
@@ -93,5 +93,5 @@ deployment of your workflow to many cloud providers like AWS, GCP, and Azure.
 If you want to run the tasks in your project on a cluster managed by coiled read
 {ref}`this guide <coiled>`.

-Otherwise, follow the instructions in [dask's
-guide](https://docs.dask.org/en/stable/deploying.html).
+Otherwise, follow the instructions in
+[dask's guide](https://docs.dask.org/en/stable/deploying.html).
diff --git a/docs/source/developers_guide.md b/docs/source/developers_guide.md
index f2fa790..b0a1c9b 100644
--- a/docs/source/developers_guide.md
+++ b/docs/source/developers_guide.md
@@ -1,10 +1,9 @@
 # Developer's Guide

 `pytask-parallel` does not call the `pytask_execute_task_protocol` hook
-specification/entry-point because `pytask_execute_task_setup` and
-`pytask_execute_task` need to be separated from `pytask_execute_task_teardown`. Thus,
-plugins that change this hook specification may not interact well with the
-parallelization.
+specification/entry-point because `pytask_execute_task_setup` and `pytask_execute_task`
+need to be separated from `pytask_execute_task_teardown`. Thus, plugins that change this
+hook specification may not interact well with the parallelization.

 Two PRs for CPython try to re-enable setting custom reducers which should have been
 working but does not. Here are the references.
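A minimal sketch of the split described in the developer's guide above, with
hypothetical stand-ins for the hook implementations: setup and execution run together
inside the worker, while the teardown waits in the main process for the future to
resolve.

```python
from concurrent.futures import ProcessPoolExecutor


def setup_and_execute(task_name: str) -> str:
    # Stand-in for pytask_execute_task_setup + pytask_execute_task,
    # which run together inside the worker process.
    return f"result of {task_name}"


def teardown(result: str) -> None:
    # Stand-in for pytask_execute_task_teardown, which must run in the
    # main process once the future has resolved.
    print(f"tearing down after receiving {result!r}")


if __name__ == "__main__":
    with ProcessPoolExecutor() as executor:
        future = executor.submit(setup_and_execute, "task_example")
        teardown(future.result())
```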
diff --git a/docs/source/index.md b/docs/source/index.md
index d362a0f..a16ed8b 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -26,6 +26,7 @@ quickstart
 coiled
 dask
 custom_executors
+remote_backends
 developers_guide
 changes
 On Github <https://github.com/pytask-dev/pytask-parallel>
diff --git a/docs/source/remote_backends.md b/docs/source/remote_backends.md
new file mode 100644
index 0000000..3a2ffbe
--- /dev/null
+++ b/docs/source/remote_backends.md
@@ -0,0 +1,47 @@
+# Remote backends
+
+There are a couple of things you need to know when using backends that launch workers
+remotely, meaning not on your machine.
+
+## Cross-platform support
+
+Issue: {issue}`102`.
+
+Currently, it is not possible to run tasks in a remote environment that has a different
+OS than your local system. The reason is that when pytask sends data to the remote
+worker, the data contains path objects, {class}`pathlib.WindowsPath` or
+{class}`pathlib.PosixPath`, which cannot be unpickled on a different system.
+
+In general, remote machines are Unix machines, so users running Unix systems
+themselves, like Linux and macOS, should have no problems.
+
+Windows users, on the other hand, should use the
+[WSL (Windows Subsystem for Linux)](https://learn.microsoft.com/en-us/windows/wsl/about)
+to run their projects.
+
+## Local files
+
+Avoid using local files with remote backends and use remote storage like S3 for
+dependencies and products. The reason is that every local file needs to be sent to the
+remote workers, and when your internet connection is slow, you will face a hefty
+runtime penalty.
+
+## Local paths
+
+In most projects, you use local paths to refer to dependencies and products of your
+tasks. This becomes an interesting problem with remote workers since your local files
+are not necessarily available on the remote machine.
+
+pytask-parallel does its best to sync files to the worker before the execution, so you
+can run your tasks locally and remotely without changing a thing.
+
+In case you create a file on the remote machine, the product will be synced back to
+your local machine as well.
+
+Keep in mind that the remote paths are temporary files that share the same file
+extension as the local file, but their name and path will differ. Do not rely on them.
+
+Another way to circumvent the problem is to first define a local task that stores all
+your necessary files in a remote storage like S3. In the remaining tasks, you can then
+use paths pointing to the bucket instead of the local machine, as sketched below. See
+the [guide on remote files](https://tinyurl.com/pytask-remote) for more explanations.
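To make the last workaround in `remote_backends.md` concrete, here is a sketch of a
task whose dependency and product both point to a bucket. It assumes
`universal_pathlib` and an S3 filesystem backend (e.g. `s3fs`) are installed;
`s3://my-bucket` is a placeholder, not a real bucket.

```python
from typing import Annotated

from upath import UPath


def task_example(
    path: UPath = UPath("s3://my-bucket/data.txt"),
) -> Annotated[str, UPath("s3://my-bucket/processed.txt")]:
    # Dependency and product both live in S3, so remote workers can reach
    # them directly and no local file has to be synced.
    return path.read_text().upper()
```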
diff --git a/pyproject.toml b/pyproject.toml
index 2b3bf6e..e99a9fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,21 +55,12 @@ content-type = "text/markdown"
 text = "MIT"

 [project.urls]
-Homepage = "https://github.com/pytask-dev/pytask-parallel"
-Changelog = "https://github.com/pytask-dev/pytask-parallel/blob/main/CHANGES.md"
-Documentation = "https://github.com/pytask-dev/pytask-parallel"
+Homepage = "https://pytask-parallel.readthedocs.io/"
+Changelog = "https://pytask-parallel.readthedocs.io/en/latest/changes.html"
+Documentation = "https://pytask-parallel.readthedocs.io/"
 Github = "https://github.com/pytask-dev/pytask-parallel"
 Tracker = "https://github.com/pytask-dev/pytask-parallel/issues"

-[tool.setuptools]
-include-package-data = true
-zip-safe = false
-platforms = ["any"]
-license-files = ["LICENSE"]
-
-[tool.check-manifest]
-ignore = ["src/pytask_parallel/_version.py"]
-
 [project.entry-points.pytask]
 pytask_parallel = "pytask_parallel.plugin"

@@ -113,6 +104,7 @@ disallow_untyped_defs = true
 no_implicit_optional = true
 warn_redundant_casts = true
 warn_unused_ignores = true
+ignore_missing_imports = true

 [[tool.mypy.overrides]]
 module = "tests.*"
@@ -127,6 +119,7 @@ unsafe-fixes = true
 [tool.ruff.lint]
 extend-ignore = [
     "ANN101",  # type annotating self
+    "ANN102",  # type annotating cls
     "ANN401",  # flake8-annotate typing.Any
     "COM812",  # Comply with ruff-format.
     "ISC001",  # Comply with ruff-format.
diff --git a/src/pytask_parallel/backends.py b/src/pytask_parallel/backends.py
index efa77cf..7fb947d 100644
--- a/src/pytask_parallel/backends.py
+++ b/src/pytask_parallel/backends.py
@@ -46,10 +46,11 @@ def submit(  # type: ignore[override]

 def _get_dask_executor(n_workers: int) -> Executor:
     """Get an executor from a dask client."""
-    _rich_traceback_omit = True
+    _rich_traceback_guard = True
     from pytask import import_optional_dependency

     distributed = import_optional_dependency("distributed")
+    assert distributed  # noqa: S101
     try:
         client = distributed.Client.current()
     except ValueError:
diff --git a/src/pytask_parallel/config.py b/src/pytask_parallel/config.py
index 0b6daf6..e39440b 100644
--- a/src/pytask_parallel/config.py
+++ b/src/pytask_parallel/config.py
@@ -27,7 +27,7 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
         raise ValueError(msg) from None

     if config["n_workers"] == "auto":
-        config["n_workers"] = max(os.cpu_count() - 1, 1)
+        config["n_workers"] = max(os.cpu_count() - 1, 1)  # type: ignore[operator]

     # If more than one worker is used, and no backend is set, use loky.
if config["n_workers"] > 1 and config["parallel_backend"] == ParallelBackend.NONE: diff --git a/src/pytask_parallel/execute.py b/src/pytask_parallel/execute.py index 093ec9c..1ee0d71 100644 --- a/src/pytask_parallel/execute.py +++ b/src/pytask_parallel/execute.py @@ -8,6 +8,7 @@ from typing import Any import cloudpickle +from _pytask.node_protocols import PPathNode from attrs import define from attrs import field from pytask import ExecutionReport @@ -24,9 +25,10 @@ from pytask_parallel.backends import WorkerType from pytask_parallel.backends import registry +from pytask_parallel.typing import CarryOverPath +from pytask_parallel.typing import is_coiled_function from pytask_parallel.utils import create_kwargs_for_task from pytask_parallel.utils import get_module -from pytask_parallel.utils import is_coiled_function from pytask_parallel.utils import parse_future_result if TYPE_CHECKING: @@ -222,7 +224,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]: from pytask_parallel.wrappers import wrap_task_in_thread return session.config["_parallel_executor"].submit( - wrap_task_in_thread, task=task, **kwargs + wrap_task_in_thread, task=task, remote=False, **kwargs ) msg = f"Unknown worker type {worker_type}" raise ValueError(msg) @@ -235,19 +237,33 @@ def pytask_unconfigure() -> None: def _update_carry_over_products( - task: PTask, carry_over_products: PyTree[PythonNode | None] | None + task: PTask, carry_over_products: PyTree[CarryOverPath | PythonNode | None] | None ) -> None: - """Update products carry over from a another process or remote worker.""" + """Update products carry over from a another process or remote worker. - def _update_carry_over_node(x: PNode, y: PythonNode | None) -> PNode: - if y: + The python node can be a regular one passing the value to another python node. + + In other instances the python holds a string or bytes from a RemotePathNode. + + """ + + def _update_carry_over_node( + x: PNode, y: CarryOverPath | PythonNode | None + ) -> PNode: + if y is None: + return x + if isinstance(x, PPathNode) and isinstance(y, CarryOverPath): + x.path.write_bytes(y.content) + return x + if isinstance(y, PythonNode): x.save(y.load()) - return x + return x + raise NotImplementedError - structure_python_nodes = tree_structure(carry_over_products) + structure_carry_over_products = tree_structure(carry_over_products) structure_produces = tree_structure(task.produces) # strict must be false when none is leaf. - if structure_produces.is_prefix(structure_python_nodes, strict=False): + if structure_produces.is_prefix(structure_carry_over_products, strict=False): task.produces = tree_map( _update_carry_over_node, task.produces, carry_over_products ) # type: ignore[assignment] diff --git a/src/pytask_parallel/nodes.py b/src/pytask_parallel/nodes.py new file mode 100644 index 0000000..e5fb0f6 --- /dev/null +++ b/src/pytask_parallel/nodes.py @@ -0,0 +1,76 @@ +"""Contains nodes for pytask-parallel.""" + +from __future__ import annotations + +import tempfile +from pathlib import Path +from typing import Any + +from _pytask.node_protocols import PNode +from _pytask.node_protocols import PPathNode +from attrs import define + +from pytask_parallel.typing import is_local_path + + +@define(kw_only=True) +class RemotePathNode(PNode): + """A class to handle path nodes with local paths in remote environments. + + Tasks may use nodes, following :class:`pytask.PPathNode`, with paths pointing to + local files. 
+    local files. These local files should be automatically available in remote
+    environments so that users do not have to worry about whether their tasks run
+    locally or remotely.
+
+    The :class:`RemotePathNode` makes it possible to send local files to remote
+    environments and back.
+
+    """
+
+    name: str
+    node: PPathNode
+    signature: str
+    value: Any
+    is_product: bool
+    remote_path: str = ""
+    fd: int = -1
+
+    @classmethod
+    def from_path_node(cls, node: PPathNode, *, is_product: bool) -> RemotePathNode:
+        """Instantiate class from path node."""
+        if not is_local_path(node.path):
+            msg = "Path is not a local path and does not need to be fixed"
+            raise ValueError(msg)
+
+        value = b"" if is_product else node.path.read_bytes()
+
+        return cls(
+            name=node.name,
+            node=node,
+            signature=node.signature,
+            value=value,
+            is_product=is_product,
+        )
+
+    def state(self) -> str | None:
+        """Calculate the state of the node."""
+        msg = "RemotePathNode does not implement .state()."
+        raise NotImplementedError(msg)
+
+    def load(self, is_product: bool = False) -> Path:  # noqa: ARG002, FBT001, FBT002
+        """Load the value."""
+        # Create a temporary file to store the value.
+        self.fd, self.remote_path = tempfile.mkstemp(suffix=self.node.path.name)
+        path = Path(self.remote_path)
+
+        # If the file is a dependency, store the value in the file.
+        if not self.is_product:
+            path.write_bytes(self.value)
+
+        # Patch path in original node and load the node.
+        self.node.path = path
+        return self.node.load(is_product=self.is_product)
+
+    def save(self, value: Any) -> None:
+        """Save strings or bytes to file."""
+        self.value = value
diff --git a/src/pytask_parallel/typing.py b/src/pytask_parallel/typing.py
new file mode 100644
index 0000000..5d0ac4b
--- /dev/null
+++ b/src/pytask_parallel/typing.py
@@ -0,0 +1,25 @@
+"""Contains functions related to typing."""
+
+from pathlib import Path
+from pathlib import PosixPath
+from pathlib import WindowsPath
+from typing import NamedTuple
+
+from pytask import PTask
+from upath.implementations.local import FilePath
+
+__all__ = ["is_coiled_function", "is_local_path"]
+
+
+def is_coiled_function(task: PTask) -> bool:
+    """Check if a function is a coiled function."""
+    return "coiled_kwargs" in task.attributes
+
+
+def is_local_path(path: Path) -> bool:
+    """Check if a path is local."""
+    return isinstance(path, (FilePath, PosixPath, WindowsPath))
+
+
+class CarryOverPath(NamedTuple):
+    content: bytes
diff --git a/src/pytask_parallel/utils.py b/src/pytask_parallel/utils.py
index a551cc8..2ab50f5 100644
--- a/src/pytask_parallel/utils.py
+++ b/src/pytask_parallel/utils.py
@@ -4,19 +4,19 @@

 import inspect
 from functools import partial
-from pathlib import PosixPath
-from pathlib import WindowsPath
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable

 from pytask import NodeLoadError
-from pytask import PathNode
 from pytask import PNode
+from pytask import PPathNode
 from pytask import PProvisionalNode
 from pytask.tree_util import PyTree
 from pytask.tree_util import tree_map_with_path
-from upath.implementations.local import FilePath
+
+from pytask_parallel.nodes import RemotePathNode
+from pytask_parallel.typing import is_local_path

 if TYPE_CHECKING:
     from concurrent.futures import Future
@@ -39,8 +39,6 @@ class CoiledFunction: ...  # type: ignore[no-redef]
 __all__ = [
     "create_kwargs_for_task",
     "get_module",
     "parse_future_result",
-    "is_local_path",
-    "is_coiled_function",
 ]

@@ -56,7 +54,7 @@ def parse_future_result(
         exc_info = _parse_future_exception(future_exception)
         return WrapperResult(
-            carry_over_products=None,
+            carry_over_products=None,  # type: ignore[arg-type]
             warning_reports=[],
             exc_info=exc_info,
             stdout="",
@@ -78,29 +76,14 @@ def _safe_load(
     # Get the argument name like "path" or "return" for function returns.
     argument = path[0]

-    # Raise an error if a PPathNode with a local path is used as a dependency or product
-    # (except as a return value).
+    # Replace local path nodes with remote path nodes if necessary.
     if (
         remote
         and argument != "return"
-        and isinstance(node, PathNode)
+        and isinstance(node, PPathNode)
         and is_local_path(node.path)
     ):
-        if is_product:
-            msg = (
-                f"You cannot use a local path as a product in argument {argument!r} "
-                "with a remote backend. Either return the content that should be saved "
-                "in the file with a return annotation "
-                "(https://tinyurl.com/pytask-return) or use a nonlocal path to store "
-                "the file in S3 or their like https://tinyurl.com/pytask-remote."
-            )
-            raise NodeLoadError(msg)
-        msg = (
-            f"You cannot use a local path as a dependency in argument {argument!r} "
-            "with a remote backend. Upload the file to a remote storage like S3 "
-            "and use the remote path instead: https://tinyurl.com/pytask-remote."
-        )
-        raise NodeLoadError(msg)
+        return RemotePathNode.from_path_node(node, is_product=is_product)

     try:
         return node.load(is_product=is_product)
@@ -145,7 +128,7 @@ def create_kwargs_for_task(task: PTask, *, remote: bool) -> dict[str, PyTree[Any]]:

 def _parse_future_exception(
     exc: BaseException | None,
-) -> tuple[type[BaseException], BaseException, TracebackType] | None:
+) -> tuple[type[BaseException], BaseException, TracebackType | None] | None:
     """Parse a future exception into the format of ``sys.exc_info``."""
     return None if exc is None else (type(exc), exc, exc.__traceback__)

@@ -165,15 +148,5 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
         func = func.func

     if path:
-        return inspect.getmodule(func, path.as_posix())
-    return inspect.getmodule(func)
-
-
-def is_local_path(path: Path) -> bool:
-    """Check if a path is local."""
-    return isinstance(path, (FilePath, PosixPath, WindowsPath))
-
-
-def is_coiled_function(task: PTask) -> bool:
-    """Check if a function is a coiled function."""
-    return "coiled_kwargs" in task.attributes
+        return inspect.getmodule(func, path.as_posix())  # type: ignore[return-value]
+    return inspect.getmodule(func)  # type: ignore[return-value]
diff --git a/src/pytask_parallel/wrappers.py b/src/pytask_parallel/wrappers.py
index b7f8a33..6157873 100644
--- a/src/pytask_parallel/wrappers.py
+++ b/src/pytask_parallel/wrappers.py
@@ -3,11 +3,14 @@

 from __future__ import annotations

 import functools
+import os
 import sys
 import warnings
 from contextlib import redirect_stderr
 from contextlib import redirect_stdout
+from contextlib import suppress
 from io import StringIO
+from pathlib import Path
 from typing import TYPE_CHECKING
 from typing import Any

@@ -22,11 +25,14 @@

 from pytask import parse_warning_filter
 from pytask import warning_record_to_str
 from pytask.tree_util import PyTree
+from pytask.tree_util import tree_map
 from pytask.tree_util import tree_map_with_path
 from pytask.tree_util import tree_structure

+from pytask_parallel.nodes import RemotePathNode
+from pytask_parallel.typing import CarryOverPath
+from pytask_parallel.typing import is_local_path
 from pytask_parallel.utils import CoiledFunction
-from pytask_parallel.utils import is_local_path

 if TYPE_CHECKING:
     from types import TracebackType
@@ -40,14 +46,16 @@

 @define(kw_only=True)
 class WrapperResult:
-    carry_over_products: PyTree[PythonNode | None]
+    carry_over_products: PyTree[CarryOverPath | PythonNode | None]
     warning_reports: list[WarningReport]
-    exc_info: tuple[type[BaseException], BaseException, TracebackType | str] | None
+    exc_info: (
+        tuple[type[BaseException], BaseException, TracebackType | str | None] | None
+    )
     stdout: str
     stderr: str


-def wrap_task_in_thread(task: PTask, **kwargs: Any) -> WrapperResult:
+def wrap_task_in_thread(task: PTask, *, remote: bool, **kwargs: Any) -> WrapperResult:
     """Mock execution function such that it returns the same as for processes.

     The function for processes returns ``warning_reports`` and an ``exception``. With
@@ -61,12 +69,12 @@
     except Exception:  # noqa: BLE001
         exc_info = sys.exc_info()
     else:
-        _handle_function_products(task, out, remote=False)
-        exc_info = None
+        _handle_function_products(task, out, remote=remote)
+        exc_info = None  # type: ignore[assignment]
     return WrapperResult(
-        carry_over_products=None,
+        carry_over_products=None,  # type: ignore[arg-type]
         warning_reports=[],
-        exc_info=exc_info,
+        exc_info=exc_info,  # type: ignore[arg-type]
         stdout="",
         stderr="",
     )
@@ -110,19 +118,25 @@ def wrap_task_in_process(  # noqa: PLR0913
         for arg in mark.args:
             warnings.filterwarnings(*parse_warning_filter(arg, escape=False))

+    processed_exc_info: tuple[type[BaseException], BaseException, str] | None
+
     try:
-        out = task.execute(**kwargs)
+        resolved_kwargs = _write_local_files_to_remote(kwargs)
+        out = task.execute(**resolved_kwargs)
     except Exception:  # noqa: BLE001
         exc_info = sys.exc_info()
         processed_exc_info = _render_traceback_to_string(
-            exc_info, show_locals, console_options
+            exc_info,  # type: ignore[arg-type]
+            show_locals,
+            console_options,
         )
         products = None
     else:
-        # Save products.
         products = _handle_function_products(task, out, remote=remote)
         processed_exc_info = None

+    _delete_local_files_on_remote(kwargs)
+
     task_display_name = getattr(task, "display_name", task.name)
     warning_reports = []
     for warning_message in log:
@@ -143,7 +157,7 @@
     captured_stderr_buffer.close()

     return WrapperResult(
-        carry_over_products=products,
+        carry_over_products=products,  # type: ignore[arg-type]
         warning_reports=warning_reports,
         exc_info=processed_exc_info,
         stdout=captured_stdout,
@@ -193,22 +207,15 @@ def _render_traceback_to_string(

 def _handle_function_products(
-    task: PTask, out: Any, *, remote: bool
-) -> PyTree[PythonNode | None]:
+    task: PTask, out: Any, *, remote: bool = False
+) -> PyTree[CarryOverPath | PythonNode | None]:
     """Handle the products of the task.

     The function's first responsibility is to push the returns of the function to the
     defined nodes.

-    Secondly, the function collects two kinds of products that need to be carried over
-    to the main process for storing them.
-
-    1. Any product that is a :class:`~pytask.PythonNode` needs to be carried over to the
-       main process as otherwise their value would be lost.
-    2. If the function is executed remotely and the return value should be stored in a
-       node with a local path like :class:`pytask.PickleNode`, we need to carry over the
-       value to the main process again and, then, save the value to the node as the
-       local path does not exist remotely.
+    Its second responsibility is to carry over products from remote to local
+    environments if the product is a :class:`PPathNode` with a local path.

     """
     # Check that the return value has the correct structure.
@@ -226,26 +233,33 @@

     def _save_and_carry_over_product(
         path: tuple[Any, ...], node: PNode
-    ) -> PythonNode | None:
+    ) -> CarryOverPath | PythonNode | None:
         argument = path[0]

+        # Handle the case when it is not a return annotation product.
         if argument != "return":
             if isinstance(node, PythonNode):
                 return node
+
+            # If the product was a local path and we are remote, we load the file
+            # content as bytes and carry it over.
+            if isinstance(node, PPathNode) and is_local_path(node.path) and remote:
+                return CarryOverPath(content=node.path.read_bytes())
             return None

+        # If it is a return value annotation, index the return until we get the value.
         value = out
         for p in path[1:]:
             value = value[p]

         # If the node is a PythonNode, we need to carry it over to the main process.
         if isinstance(node, PythonNode):
-            node.save(value)
+            node.save(value=value)
             return node

         # If the path is local and we are remote, we need to carry over the value to
         # the main process as a PythonNode and save it later.
-        if remote and isinstance(node, PPathNode) and is_local_path(node.path):
+        if isinstance(node, PPathNode) and is_local_path(node.path) and remote:
             return PythonNode(value=value)

         # If no condition applies, we save the value and do not carry it over. Like a
@@ -254,3 +268,32 @@
         return None

     return tree_map_with_path(_save_and_carry_over_product, task.produces)
+
+
+def _write_local_files_to_remote(
+    kwargs: dict[str, PyTree[Any]],
+) -> dict[str, PyTree[Any]]:
+    """Write local files to remote.
+
+    The main process passes kwargs that can contain RemotePathNodes, which need to be
+    resolved (loaded) on the worker.
+
+    """
+    return tree_map(lambda x: x.load() if isinstance(x, RemotePathNode) else x, kwargs)  # type: ignore[return-value]
+
+
+def _delete_local_files_on_remote(kwargs: dict[str, PyTree[Any]]) -> None:
+    """Delete local files on remote.
+
+    Local files were copied over to the remote machine via RemotePathNodes. We need to
+    delete them after the task is executed.
+
+    """
+
+    def _delete(potential_node: Any) -> None:
+        if isinstance(potential_node, RemotePathNode):
+            with suppress(OSError):
+                os.close(potential_node.fd)
+            Path(potential_node.remote_path).unlink(missing_ok=True)
+
+    tree_map(_delete, kwargs)
diff --git a/tests/test_remote.py b/tests/test_remote.py
index cfff33c..366e30a 100644
--- a/tests/test_remote.py
+++ b/tests/test_remote.py
@@ -87,8 +87,9 @@ def task_example(path: Path = Path("in.txt")) -> Annotated[str, Path("output.txt")]:
             tmp_path.joinpath("config.py").as_posix(),
         ],
     )
-    assert result.exit_code == ExitCode.FAILED
-    assert "You cannot use a local path" in result.output
+    assert result.exit_code == ExitCode.OK
+    assert "1 Succeeded" in result.output
+    assert tmp_path.joinpath("output.txt").read_text() == "Hello World!"
 @pytest.mark.end_to_end()
@@ -113,8 +114,9 @@ def task_example(path: Annotated[Path, Product] = Path("output.txt")):
             tmp_path.joinpath("config.py").as_posix(),
         ],
     )
-    assert result.exit_code == ExitCode.FAILED
-    assert "You cannot use a local path" in result.output
+    assert result.exit_code == ExitCode.OK
+    assert "1 Succeeded" in result.output
+    assert tmp_path.joinpath("output.txt").read_text() == "Hello World!"


 @pytest.mark.end_to_end()
diff --git a/tox.ini b/tox.ini
index cb55d00..b2057a0 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,6 +5,15 @@ envlist = test
 [testenv]
 package = editable

+[testenv:typing]
+deps =
+    mypy
+    git+https://github.com/pytask-dev/pytask.git@main
+extras =
+    dask
+    coiled
+commands = mypy
+
 [testenv:test]
 extras = test
 deps =
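For reference, the round trip that `RemotePathNode` in this diff implements can be
sketched as follows (illustrative only; it assumes a local file `in.txt` exists in the
working directory):

```python
from pathlib import Path

from pytask import PathNode

from pytask_parallel.nodes import RemotePathNode

node = PathNode(name="in", path=Path("in.txt").resolve())
remote_node = RemotePathNode.from_path_node(node, is_product=False)  # reads the bytes

# On a worker, load() writes the bytes to a temporary file and returns its path.
path_on_worker = remote_node.load()
assert path_on_worker.read_bytes() == Path("in.txt").read_bytes()
```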