Skip to content

Commit 3d28b1f

Browse files
committed
Fix.
1 parent fcb8279 commit 3d28b1f

File tree

4 files changed

+38
-46
lines changed

4 files changed

+38
-46
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,13 @@ repos:
7272
--ignore-missing-imports,
7373
]
7474
additional_dependencies: [
75+
cloudpickle,
76+
optree,
77+
pytask==0.4.0rc2,
78+
rich,
7579
types-attrs,
7680
types-click,
7781
types-setuptools,
78-
pytask==0.4.0rc2,
79-
rich,
80-
optree,
8182
]
8283
pass_filenames: false
8384
- repo: https://github.com/mgedmin/check-manifest

src/pytask_parallel/backends.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
from typing import Callable
1010

1111
import cloudpickle
12+
from _pytask.path import import_path
1213

1314

1415
def deserialize_and_run_with_cloudpickle(
15-
fn: Callable[..., Any], kwargs: dict[str, Any]
16+
fn: bytes, kwargs: bytes, kwargs_import_path: bytes
1617
) -> Any:
1718
"""Deserialize and execute a function and keyword arguments."""
19+
deserialized_kwargs_import_path = cloudpickle.loads(kwargs_import_path)
20+
import_path(**deserialized_kwargs_import_path)
21+
1822
deserialized_fn = cloudpickle.loads(fn)
1923
deserialized_kwargs = cloudpickle.loads(kwargs)
2024
return deserialized_fn(**deserialized_kwargs)
@@ -31,6 +35,7 @@ def submit( # type: ignore[override]
3135
return super().submit(
3236
deserialize_and_run_with_cloudpickle,
3337
fn=cloudpickle.dumps(fn),
38+
kwargs_import_path=cloudpickle.dumps(kwargs.pop("kwargs_import_path")),
3439
kwargs=cloudpickle.dumps(kwargs),
3540
)
3641

src/pytask_parallel/execute.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
from pytask import Mark
2020
from pytask import parse_warning_filter
2121
from pytask import PTask
22+
from pytask import PTaskWithPath
2223
from pytask import remove_internal_traceback_frames_from_exc_info
2324
from pytask import Session
2425
from pytask import Task
2526
from pytask import warning_record_to_str
2627
from pytask import WarningReport
28+
from pytask.tree_util import PyTree
2729
from pytask.tree_util import tree_leaves
2830
from pytask.tree_util import tree_map
2931
from pytask.tree_util import tree_structure
@@ -175,7 +177,7 @@ class ProcessesNameSpace:
175177

176178
@staticmethod
177179
@hookimpl(tryfirst=True)
178-
def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
180+
def pytask_execute_task(session: Session, task: PTask) -> Future[Any] | None:
179181
"""Execute a task.
180182
181183
Take a task, pickle it and send the bytes over to another process.
@@ -184,6 +186,11 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
184186
if session.config["n_workers"] > 1:
185187
kwargs = _create_kwargs_for_task(task)
186188

189+
if sys.platform == "win32" and isinstance(task, PTaskWithPath):
190+
kwargs_import_path = {"path": task.path, "root": session.config["root"]}
191+
else:
192+
kwargs_import_path = None
193+
187194
return session.config["_parallel_executor"].submit(
188195
_execute_task,
189196
task=task,
@@ -192,6 +199,7 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
192199
console_options=console.options,
193200
session_filterwarnings=session.config["filterwarnings"],
194201
task_filterwarnings=get_marks(task, "filterwarnings"),
202+
kwargs_import_path=kwargs_import_path,
195203
)
196204
return None
197205

@@ -319,7 +327,7 @@ def _mock_processes_for_threads(
319327
return [], exc_info
320328

321329

322-
def _create_kwargs_for_task(task: Task) -> dict[Any, Any]:
330+
def _create_kwargs_for_task(task: PTask) -> dict[str, PyTree[Any]]:
323331
"""Create kwargs for task function."""
324332
parameters = inspect.signature(task.function).parameters
325333

tests/test_execute.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -112,46 +112,6 @@ def task_2(produces):
112112
assert end - start < 10
113113

114114

115-
@pytest.mark.integration()
116-
@pytest.mark.parametrize("parallel_backend", _PARALLEL_BACKENDS_PARAMETRIZATION)
117-
def test_pytask_execute_task_w_processes(parallel_backend):
118-
# Local function which cannot be used with multiprocessing.
119-
def myfunc():
120-
return 1
121-
122-
# Verify it cannot be used with multiprocessing because it cannot be pickled.
123-
with pytest.raises(AttributeError):
124-
pickle.dumps(myfunc)
125-
126-
task = Task(base_name="task_example", path=Path(), function=myfunc)
127-
128-
session = Session()
129-
session.config = {
130-
"n_workers": 2,
131-
"parallel_backend": parallel_backend,
132-
"show_locals": False,
133-
"filterwarnings": [],
134-
}
135-
136-
with PARALLEL_BACKENDS[parallel_backend](
137-
max_workers=session.config["n_workers"]
138-
) as executor:
139-
session.config["_parallel_executor"] = executor
140-
141-
backend_name_space = {
142-
ParallelBackendChoices.PROCESSES: ProcessesNameSpace,
143-
ParallelBackendChoices.THREADS: DefaultBackendNameSpace,
144-
ParallelBackendChoices.LOKY: DefaultBackendNameSpace,
145-
}[parallel_backend]
146-
147-
future = backend_name_space.pytask_execute_task(session, task)
148-
executor.shutdown()
149-
150-
warning_reports, exception = future.result()
151-
assert warning_reports == []
152-
assert exception is None
153-
154-
155115
@pytest.mark.end_to_end()
156116
@pytest.mark.parametrize("parallel_backend", _PARALLEL_BACKENDS_PARAMETRIZATION)
157117
def test_stop_execution_when_max_failures_is_reached(tmp_path, parallel_backend):
@@ -318,3 +278,21 @@ def task_example() -> Annotated[str, Path("file.txt")]:
318278
)
319279
assert result.exit_code == ExitCode.OK
320280
assert tmp_path.joinpath("file.txt").exists()
281+
282+
283+
284+
@pytest.mark.end_to_end()
285+
@pytest.mark.parametrize("parallel_backend", _PARALLEL_BACKENDS_PARAMETRIZATION)
286+
def test_task_without_path_that_return(runner, tmp_path, parallel_backend):
287+
source = """
288+
from pathlib import Path
289+
from pytask import task
290+
291+
task_example = task(produces=Path("file.txt"))(lambda *x: "Hello, Darkness, my old friend.")
292+
"""
293+
tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
294+
result = runner.invoke(
295+
cli, [tmp_path.as_posix(), "--parallel-backend", parallel_backend]
296+
)
297+
assert result.exit_code == ExitCode.OK
298+
assert tmp_path.joinpath("file.txt").exists()

0 commit comments

Comments
 (0)