Commit 9df24e8

Capture warnings. (#44)
1 parent deba157 commit 9df24e8

File tree: 5 files changed (+173 additions, -45 deletions)

CHANGES.md

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,12 @@ chronological order. Releases follow [semantic versioning](https://semver.org/)
 releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
 [Anaconda.org](https://anaconda.org/conda-forge/pytask-parallel).
 
-## 0.2.0 - 2022-xx-xx
+## 0.2.1 - 2022-08-xx
+
+- {pull}`43` adds docformatter.
+- {pull}`44` allows to capture warnings from subprocesses. Fixes {issue}`41`.
+
+## 0.2.0 - 2022-04-15
 
 - {pull}`31` adds types to the package.
 - {pull}`36` adds a test for <https://github.com/pytask-dev/pytask/issues/216>.

README.md

Lines changed: 8 additions & 1 deletion
@@ -68,14 +68,21 @@ n_workers = 1
 parallel_backend = "loky" # or processes or threads
 ```
 
-## Warning
+## Some implementation details
+
+### Parallelization and Debugging
 
 It is not possible to combine parallelization with debugging. That is why `--pdb` or
 `--trace` deactivate parallelization.
 
 If you parallelize the execution of your tasks using two or more workers, do not use
 `breakpoint()` or `import pdb; pdb.set_trace()` since both will cause exceptions.
 
+### Threads and warnings
+
+Capturing warnings is not thread-safe. Therefore, warnings cannot be captured reliably
+when tasks are parallelized with `--parallel-backend threads`.
+
 ## Changes
 
 Consult the [release notes](CHANGES.md) to find out about what is new.
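
The new README note on threads reflects how the standard library's `warnings` machinery works rather than anything pytask-specific: `warnings.catch_warnings()` and `warnings.filterwarnings()` swap module-global state (`warnings.filters` and `warnings.showwarning`), not thread-local state, so concurrently running tasks would overwrite each other's filters and recorders. A minimal stdlib-only sketch, purely for illustration, showing that the recorded log is process-wide:

```python
import threading
import warnings


def worker() -> None:
    # The warning is emitted from a separate thread.
    warnings.warn("warning from a worker thread")


with warnings.catch_warnings(record=True) as log:
    warnings.simplefilter("always")
    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

# The main thread's recorder caught the worker's warning, because
# catch_warnings() patches process-wide state, not per-thread state.
print([str(w.message) for w in log])  # ['warning from a worker thread']
```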

src/pytask_parallel/execute.py

Lines changed: 118 additions & 42 deletions
@@ -4,22 +4,36 @@
 import inspect
 import sys
 import time
+import warnings
 from concurrent.futures import Future
 from types import TracebackType
 from typing import Any
+from typing import Callable
 
 import cloudpickle
 from pybaum.tree_util import tree_map
 from pytask import console
 from pytask import ExecutionReport
+from pytask import get_marks
 from pytask import hookimpl
+from pytask import Mark
 from pytask import remove_internal_traceback_frames_from_exc_info
 from pytask import Session
 from pytask import Task
 from pytask_parallel.backends import PARALLEL_BACKENDS
 from rich.console import ConsoleOptions
 from rich.traceback import Traceback
 
+# Can be removed if pinned to pytask >= 0.2.6.
+try:
+    from pytask import parse_warning_filter
+    from pytask import warning_record_to_str
+    from pytask import WarningReport
+except ImportError:
+    from _pytask.warnings import parse_warning_filter
+    from _pytask.warnings import warning_record_to_str
+    from _pytask.warnings_utils import WarningReport
+
 
 @hookimpl
 def pytask_post_parse(config: dict[str, Any]) -> None:
@@ -85,42 +99,38 @@ def pytask_execute_build(session: Session) -> bool | None:
 
                     for task_name in list(running_tasks):
                         future = running_tasks[task_name]
-                        if future.done() and (
-                            future.exception() is not None
-                            or future.result() is not None
-                        ):
-                            task = session.dag.nodes[task_name]["task"]
-                            if future.exception() is not None:
-                                exception = future.exception()
-                                exc_info = (
-                                    type(exception),
-                                    exception,
-                                    exception.__traceback__,
-                                )
-                            else:
-                                exc_info = future.result()
-
-                            newly_collected_reports.append(
-                                ExecutionReport.from_task_and_exception(task, exc_info)
+                        if future.done():
+                            warning_reports, task_exception = future.result()
+                            session.warnings.extend(warning_reports)
+                            exc_info = (
+                                _parse_future_exception(future.exception())
+                                or task_exception
                             )
-                            running_tasks.pop(task_name)
-                            session.scheduler.done(task_name)
-                        elif future.done() and future.exception() is None:
-                            task = session.dag.nodes[task_name]["task"]
-                            try:
-                                session.hook.pytask_execute_task_teardown(
-                                    session=session, task=task
-                                )
-                            except Exception:
-                                report = ExecutionReport.from_task_and_exception(
-                                    task, sys.exc_info()
+                            if exc_info is not None:
+                                task = session.dag.nodes[task_name]["task"]
+                                newly_collected_reports.append(
+                                    ExecutionReport.from_task_and_exception(
+                                        task, exc_info
+                                    )
                                 )
+                                running_tasks.pop(task_name)
+                                session.scheduler.done(task_name)
                             else:
-                                report = ExecutionReport.from_task(task)
-
-                            running_tasks.pop(task_name)
-                            newly_collected_reports.append(report)
-                            session.scheduler.done(task_name)
+                                task = session.dag.nodes[task_name]["task"]
+                                try:
+                                    session.hook.pytask_execute_task_teardown(
+                                        session=session, task=task
+                                    )
+                                except Exception:
+                                    report = ExecutionReport.from_task_and_exception(
+                                        task, sys.exc_info()
+                                    )
+                                else:
+                                    report = ExecutionReport.from_task(task)
+
+                                running_tasks.pop(task_name)
+                                newly_collected_reports.append(report)
+                                session.scheduler.done(task_name)
                         else:
                             pass
 
@@ -144,6 +154,17 @@ def pytask_execute_build(session: Session) -> bool | None:
     return None
 
 
+def _parse_future_exception(
+    exception: BaseException | None,
+) -> tuple[type[BaseException], BaseException, TracebackType] | None:
+    """Parse a future exception."""
+    return (
+        None
+        if exception is None
+        else (type(exception), exception, exception.__traceback__)
+    )
+
+
 class ProcessesNameSpace:
     """The name space for hooks related to processes."""
 
@@ -167,6 +188,9 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
                 bytes_kwargs=bytes_kwargs,
                 show_locals=session.config["show_locals"],
                 console_options=console.options,
+                session_filterwarnings=session.config["filterwarnings"],
+                task_filterwarnings=get_marks(task, "filterwarnings"),
+                task_short_name=task.short_name,
             )
         return None
 
@@ -176,7 +200,10 @@ def _unserialize_and_execute_task(
     bytes_kwargs: bytes,
     show_locals: bool,
     console_options: ConsoleOptions,
-) -> tuple[type[BaseException], BaseException, str] | None:
+    session_filterwarnings: tuple[str, ...],
+    task_filterwarnings: tuple[Mark, ...],
+    task_short_name: str,
+) -> tuple[list[WarningReport], tuple[type[BaseException], BaseException, str] | None]:
     """Unserialize and execute task.
 
     This function receives bytes and unpickles them to a task which is them execute in a
@@ -188,13 +215,40 @@
     task = cloudpickle.loads(bytes_function)
     kwargs = cloudpickle.loads(bytes_kwargs)
 
-    try:
-        task.execute(**kwargs)
-    except Exception:
-        exc_info = sys.exc_info()
-        processed_exc_info = _process_exception(exc_info, show_locals, console_options)
-        return processed_exc_info
-    return None
+    with warnings.catch_warnings(record=True) as log:
+        # mypy can't infer that record=True means log is not None; help it.
+        assert log is not None
+
+        for arg in session_filterwarnings:
+            warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        # apply filters from "filterwarnings" marks
+        for mark in task_filterwarnings:
+            for arg in mark.args:
+                warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
+
+        try:
+            task.execute(**kwargs)
+        except Exception:
+            exc_info = sys.exc_info()
+            processed_exc_info = _process_exception(
+                exc_info, show_locals, console_options
+            )
+        else:
+            processed_exc_info = None
+
+    warning_reports = []
+    for warning_message in log:
+        fs_location = warning_message.filename, warning_message.lineno
+        warning_reports.append(
+            WarningReport(
+                message=warning_record_to_str(warning_message),
+                fs_location=fs_location,
+                id_=task_short_name,
+            )
+        )
+
+    return warning_reports, processed_exc_info
 
 
 def _process_exception(
@@ -224,11 +278,33 @@ def pytask_execute_task(session: Session, task: Task) -> Future[Any] | None:
         """
         if session.config["n_workers"] > 1:
             kwargs = _create_kwargs_for_task(task)
-            return session.executor.submit(task.execute, **kwargs)
+            return session.executor.submit(
+                _mock_processes_for_threads, func=task.execute, **kwargs
+            )
         else:
            return None
 
 
+def _mock_processes_for_threads(
+    func: Callable[..., Any], **kwargs: Any
+) -> tuple[list[Any], tuple[type[BaseException], BaseException, TracebackType] | None]:
+    """Mock execution function such that it returns the same as for processes.
+
+    The function for processes returns ``warning_reports`` and an ``exception``. With
+    threads, these object are collected by the main and not the subprocess. So, we just
+    return placeholders.
+
+    """
+    __tracebackhide__ = True
+    try:
+        func(**kwargs)
+    except Exception:
+        exc_info = sys.exc_info()
+    else:
+        exc_info = None
+    return [], exc_info
+
+
 def _create_kwargs_for_task(task: Task) -> dict[Any, Any]:
     """Create kwargs for task function."""
     kwargs = {**task.kwargs}
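
Taken together, the diff makes the worker return a pair of `(warning_reports, exc_info)` and makes the scheduling loop unpack that pair from `future.result()`. Below is a condensed, stand-alone sketch of that round trip using plain tuples instead of pytask's `WarningReport` and `ExecutionReport`; the names `run_and_capture` and `task_example` and the tuple layout are illustrative, not part of the commit:

```python
from __future__ import annotations

import traceback
import warnings
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Callable


def run_and_capture(
    func: Callable[..., Any], **kwargs: Any
) -> tuple[list[tuple[str, str, int]], tuple[type[BaseException], BaseException, str] | None]:
    """Run ``func`` in a worker and return captured warnings plus exception info."""
    with warnings.catch_warnings(record=True) as log:
        warnings.simplefilter("always")
        try:
            func(**kwargs)
        except Exception as e:
            # Like _process_exception, ship a rendered traceback string back,
            # because traceback objects cannot be pickled across processes.
            exc_info = (type(e), e, traceback.format_exc())
        else:
            exc_info = None
    reports = [(str(w.message), w.filename, w.lineno) for w in log]
    return reports, exc_info


def task_example() -> None:
    warnings.warn("This is a warning.")


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as executor:
        future = executor.submit(run_and_capture, task_example)
        # The main process unpacks both pieces from a single result, analogous
        # to how pytask_execute_build extends session.warnings and only builds
        # an ExecutionReport when exc_info is not None.
        warning_reports, exc_info = future.result()
    print(warning_reports)
    print(exc_info)
```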

tests/test_execute.py

Lines changed: 38 additions & 1 deletion
@@ -119,6 +119,7 @@ def myfunc():
         "n_workers": 2,
         "parallel_backend": parallel_backend,
         "show_locals": False,
+        "filterwarnings": [],
     }
 
     with PARALLEL_BACKENDS[parallel_backend](
@@ -135,7 +136,9 @@ def myfunc():
         future = backend_name_space.pytask_execute_task(session, task)
         executor.shutdown()
 
-    assert future.result() is None
+    warning_reports, exception = future.result()
+    assert warning_reports == []
+    assert exception is None
 
 
 @pytest.mark.end_to_end
@@ -288,3 +291,37 @@ def task_example(produces):
     )
 
     assert session.exit_code == ExitCode.OK
+
+
+@pytest.mark.end_to_end
+@pytest.mark.parametrize(
+    "parallel_backend",
+    # Capturing warnings is not thread-safe.
+    [backend for backend in PARALLEL_BACKENDS if backend != "threads"],
+)
+def test_collect_warnings_from_parallelized_tasks(runner, tmp_path, parallel_backend):
+    source = """
+    import pytask
+    import warnings
+
+    for i in range(2):
+
+        @pytask.mark.task(id=i, kwargs={"produces": f"{i}.txt"})
+        def task_example(produces):
+            warnings.warn("This is a warning.")
+            produces.touch()
+    """
+    tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source))
+
+    result = runner.invoke(
+        cli, [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
+    )
+
+    assert result.exit_code == ExitCode.OK
+    assert "Warnings" in result.output
+    assert "This is a warning." in result.output
+    assert "capture_warnings.html" in result.output
+
+    warnings_block = result.output.split("Warnings")[1]
+    assert "task_example.py::task_example[0]" in warnings_block
+    assert "task_example.py::task_example[1]" in warnings_block

tox.ini

Lines changed: 3 additions & 0 deletions
@@ -39,6 +39,9 @@ warn-symbols =
     pytest.mark.skip = Remove 'skip' flag for tests.
 
 [pytest]
+testpaths =
+    # Do not add src since it messes with the loading of pytask-parallel as a plugin.
+    tests
 addopts = --doctest-modules
 filterwarnings =
     ignore: the imp module is deprecated in favour of importlib
