Skip to content

Commit aed5517

Browse files
committed
Fix types.
1 parent 316355e commit aed5517

File tree

12 files changed

+108
-110
lines changed

12 files changed

+108
-110
lines changed

.github/workflows/main.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@ on:
1818

1919
jobs:
2020

21-
build-package:
22-
name: Build & verify package
21+
run-type-checking:
22+
23+
name: Run tests for type-checking
2324
runs-on: ubuntu-latest
2425

2526
steps:
2627
- uses: actions/checkout@v4
28+
- uses: actions/setup-python@v5
2729
with:
28-
fetch-depth: 0
29-
30-
- uses: hynek/build-and-inspect-python-package@v2
31-
id: baipp
32-
33-
outputs:
34-
python-versions: ${{ steps.baipp.outputs.supported_python_classifiers_json_array }}
30+
python-version-file: .python-version
31+
allow-prereleases: true
32+
cache: pip
33+
- run: pip install tox-uv
34+
- run: tox -e typing
3535

3636
run-tests:
3737

.pre-commit-config.yaml

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,6 @@ repos:
5959
rev: v2.2.6
6060
hooks:
6161
- id: codespell
62-
- repo: https://github.com/pre-commit/mirrors-mypy
63-
rev: 'v1.9.0'
64-
hooks:
65-
- id: mypy
66-
args: [
67-
--no-strict-optional,
68-
--ignore-missing-imports,
69-
]
70-
additional_dependencies: [
71-
attrs,
72-
cloudpickle,
73-
loky,
74-
"git+https://github.com/pytask-dev/pytask.git@main",
75-
rich,
76-
types-click,
77-
types-setuptools,
78-
]
79-
pass_filenames: false
8062
- repo: meta
8163
hooks:
8264
- id: check-hooks-apply

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ disallow_untyped_defs = true
104104
no_implicit_optional = true
105105
warn_redundant_casts = true
106106
warn_unused_ignores = true
107+
ignore_missing_imports = true
107108

108109
[[tool.mypy.overrides]]
109110
module = "tests.*"

src/pytask_parallel/backends.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def _get_dask_executor(n_workers: int) -> Executor:
5050
from pytask import import_optional_dependency
5151

5252
distributed = import_optional_dependency("distributed")
53+
assert distributed # noqa: S101
5354
try:
5455
client = distributed.Client.current()
5556
except ValueError:

src/pytask_parallel/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def pytask_parse_config(config: dict[str, Any]) -> None:
2727
raise ValueError(msg) from None
2828

2929
if config["n_workers"] == "auto":
30-
config["n_workers"] = max(os.cpu_count() - 1, 1)
30+
config["n_workers"] = max(os.cpu_count() - 1, 1) # type: ignore[operator]
3131

3232
# If more than one worker is used, and no backend is set, use loky.
3333
if config["n_workers"] > 1 and config["parallel_backend"] == ParallelBackend.NONE:

src/pytask_parallel/execute.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any
99

1010
import cloudpickle
11+
from _pytask.node_protocols import PPathNode
1112
from attrs import define
1213
from attrs import field
1314
from pytask import ExecutionReport
@@ -24,6 +25,7 @@
2425

2526
from pytask_parallel.backends import WorkerType
2627
from pytask_parallel.backends import registry
28+
from pytask_parallel.typing import CarryOverPath
2729
from pytask_parallel.typing import is_coiled_function
2830
from pytask_parallel.utils import create_kwargs_for_task
2931
from pytask_parallel.utils import get_module
@@ -235,7 +237,7 @@ def pytask_unconfigure() -> None:
235237

236238

237239
def _update_carry_over_products(
238-
task: PTask, carry_over_products: PyTree[PythonNode | None] | None
240+
task: PTask, carry_over_products: PyTree[CarryOverPath | PythonNode | None] | None
239241
) -> None:
240242
"""Update products carry over from a another process or remote worker.
241243
@@ -245,10 +247,18 @@ def _update_carry_over_products(
245247
246248
"""
247249

248-
def _update_carry_over_node(x: PNode, y: PythonNode | None) -> PNode:
249-
if y:
250+
def _update_carry_over_node(
251+
x: PNode, y: CarryOverPath | PythonNode | None
252+
) -> PNode:
253+
if y is None:
254+
return x
255+
if isinstance(x, PPathNode) and isinstance(y, CarryOverPath):
256+
x.path.write_bytes(y.content)
257+
return x
258+
if isinstance(y, PythonNode):
250259
x.save(y.load())
251-
return x
260+
return x
261+
raise NotImplementedError
252262

253263
structure_carry_over_products = tree_structure(carry_over_products)
254264
structure_produces = tree_structure(task.produces)

src/pytask_parallel/nodes.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,41 @@
22

33
from __future__ import annotations
44

5-
import os
65
import tempfile
76
from pathlib import Path
8-
from typing import TYPE_CHECKING
7+
from typing import Any
98

109
from _pytask.node_protocols import PNode
10+
from _pytask.node_protocols import PPathNode
1111
from attrs import define
1212

1313
from pytask_parallel.typing import is_local_path
1414

15-
if TYPE_CHECKING:
16-
from pytask import PathNode
17-
1815

1916
@define(kw_only=True)
2017
class RemotePathNode(PNode):
21-
"""The class for a node which is a path."""
18+
"""A class to handle path nodes with local paths in remote environments.
19+
20+
Tasks may use nodes, following :class:`pytask.PPathNode`, with paths pointing to
21+
local files. These local files should be automatically available in remote
22+
environments so that users do not have to care about running their tasks locally or
23+
remotely.
24+
25+
The :class:`RemotePathNode` allows to send local files over to remote environments
26+
and back.
27+
28+
"""
2229

2330
name: str
24-
local_path: str
31+
node: PPathNode
2532
signature: str
26-
value: str | bytes
33+
value: Any
34+
is_product: bool
2735
remote_path: str = ""
2836
fd: int = -1
2937

3038
@classmethod
31-
def from_path_node(cls, node: PathNode, *, is_product: bool) -> RemotePathNode:
39+
def from_path_node(cls, node: PPathNode, *, is_product: bool) -> RemotePathNode:
3240
"""Instantiate class from path node."""
3341
if not is_local_path(node.path):
3442
msg = "Path is not a local path and does not need to be fixed"
@@ -38,34 +46,31 @@ def from_path_node(cls, node: PathNode, *, is_product: bool) -> RemotePathNode:
3846

3947
return cls(
4048
name=node.name,
41-
local_path=node.path.as_posix(),
49+
node=node,
4250
signature=node.signature,
4351
value=value,
52+
is_product=is_product,
4453
)
4554

4655
def state(self) -> str | None:
4756
"""Calculate the state of the node."""
4857
msg = "RemotePathNode does not implement .state()."
4958
raise NotImplementedError(msg)
5059

51-
def load(self, is_product: bool = False) -> Path: # noqa: FBT001, FBT002
60+
def load(self, is_product: bool = False) -> Path: # noqa: ARG002, FBT001, FBT002
5261
"""Load the value."""
5362
# Create a temporary file to store the value.
54-
ext = os.path.splitext(self.local_path)[1] # noqa: PTH122
55-
self.fd, self.remote_path = tempfile.mkstemp(suffix=ext)
63+
self.fd, self.remote_path = tempfile.mkstemp(suffix=self.node.path.name)
64+
path = Path(self.remote_path)
5665

5766
# If the file is a dependency, store the value in the file.
58-
path = Path(self.remote_path)
59-
if not is_product:
60-
path.write_text(self.value) if isinstance(
61-
self.value, str
62-
) else path.write_bytes(self.value)
63-
return path
67+
if not self.is_product:
68+
path.write_bytes(self.value)
69+
70+
# Patch path in original node and load the node.
71+
self.node.path = path
72+
return self.node.load(is_product=self.is_product)
6473

65-
def save(self, value: bytes | str) -> None:
74+
def save(self, value: Any) -> None:
6675
"""Save strings or bytes to file."""
67-
if isinstance(value, (bytes, str)):
68-
self.value = value
69-
else:
70-
msg = f"'RemotePathNode' can only save 'str' and 'bytes', not {type(value)}"
71-
raise TypeError(msg)
76+
self.value = value

src/pytask_parallel/typing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from pathlib import PosixPath
55
from pathlib import WindowsPath
6+
from typing import NamedTuple
67

78
from pytask import PTask
89
from upath.implementations.local import FilePath
@@ -18,3 +19,7 @@ def is_coiled_function(task: PTask) -> bool:
1819
def is_local_path(path: Path) -> bool:
1920
"""Check if a path is local."""
2021
return isinstance(path, (FilePath, PosixPath, WindowsPath))
22+
23+
24+
class CarryOverPath(NamedTuple):
25+
content: bytes

src/pytask_parallel/utils.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,9 @@
99
from typing import Callable
1010

1111
from pytask import NodeLoadError
12-
from pytask import PathNode
13-
from pytask import PickleNode
1412
from pytask import PNode
1513
from pytask import PPathNode
1614
from pytask import PProvisionalNode
17-
from pytask import PythonNode
1815
from pytask.tree_util import PyTree
1916
from pytask.tree_util import tree_map_with_path
2017

@@ -57,7 +54,7 @@ def parse_future_result(
5754

5855
exc_info = _parse_future_exception(future_exception)
5956
return WrapperResult(
60-
carry_over_products=None,
57+
carry_over_products=None, # type: ignore[arg-type]
6158
warning_reports=[],
6259
exc_info=exc_info,
6360
stdout="",
@@ -86,12 +83,7 @@ def _safe_load(
8683
and isinstance(node, PPathNode)
8784
and is_local_path(node.path)
8885
):
89-
if isinstance(node, PathNode):
90-
return RemotePathNode.from_path_node(node, is_product=is_product)
91-
if isinstance(node, PickleNode):
92-
if is_product:
93-
return PythonNode()
94-
return node.load()
86+
return RemotePathNode.from_path_node(node, is_product=is_product)
9587

9688
try:
9789
return node.load(is_product=is_product)
@@ -136,7 +128,7 @@ def create_kwargs_for_task(task: PTask, *, remote: bool) -> dict[str, PyTree[Any
136128

137129
def _parse_future_exception(
138130
exc: BaseException | None,
139-
) -> tuple[type[BaseException], BaseException, TracebackType] | None:
131+
) -> tuple[type[BaseException], BaseException, TracebackType | None] | None:
140132
"""Parse a future exception into the format of ``sys.exc_info``."""
141133
return None if exc is None else (type(exc), exc, exc.__traceback__)
142134

@@ -156,5 +148,5 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
156148
func = func.func
157149

158150
if path:
159-
return inspect.getmodule(func, path.as_posix())
160-
return inspect.getmodule(func)
151+
return inspect.getmodule(func, path.as_posix()) # type: ignore[return-value]
152+
return inspect.getmodule(func) # type: ignore[return-value]

0 commit comments

Comments
 (0)