Skip to content

Redirect stdout and stderr. #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ releases are available on [PyPI](https://pypi.org/project/pytask-parallel) and
- {pull}`85` simplifies code since loky is a dependency.
- {pull}`88` updates handling `Traceback`.
- {pull}`89` restructures the package.
- {pull}`92` redirects stdout and stderr from processes and loky and shows them in error
reports.

## 0.4.1 - 2024-01-12

Expand Down
19 changes: 16 additions & 3 deletions src/pytask_parallel/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,24 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
future = running_tasks[task_name]

if future.done():
python_nodes, warnings_reports, exc_info = parse_future_result(
future
)
(
python_nodes,
warnings_reports,
exc_info,
captured_stdout,
captured_stderr,
) = parse_future_result(future)
session.warnings.extend(warnings_reports)

if captured_stdout:
task.report_sections.append(
("call", "stdout", captured_stdout)
)
if captured_stderr:
task.report_sections.append(
("call", "stderr", captured_stderr)
)

if exc_info is not None:
task = session.dag.nodes[task_name]["task"]
newly_collected_reports.append(
Expand Down
27 changes: 25 additions & 2 deletions src/pytask_parallel/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import inspect
import sys
import warnings
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from functools import partial
from io import StringIO
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
Expand Down Expand Up @@ -98,6 +101,8 @@ def _execute_task( # noqa: PLR0913
PyTree[PythonNode | None],
list[WarningReport],
tuple[type[BaseException], BaseException, str] | None,
str,
str,
]:
"""Unserialize and execute task.

Expand All @@ -111,8 +116,13 @@ def _execute_task( # noqa: PLR0913
# Patch set_trace and breakpoint to show a better error message.
_patch_set_trace_and_breakpoint()

captured_stdout_buffer = StringIO()
captured_stderr_buffer = StringIO()

# Catch warnings and store them in a list.
with warnings.catch_warnings(record=True) as log:
with warnings.catch_warnings(record=True) as log, redirect_stdout(
captured_stdout_buffer
), redirect_stderr(captured_stderr_buffer):
# Apply global filterwarnings.
for arg in session_filterwarnings:
warnings.filterwarnings(*parse_warning_filter(arg, escape=False))
Expand Down Expand Up @@ -146,12 +156,25 @@ def _execute_task( # noqa: PLR0913
)
)

captured_stdout_buffer.seek(0)
captured_stderr_buffer.seek(0)
captured_stdout = captured_stdout_buffer.read()
captured_stderr = captured_stderr_buffer.read()
captured_stdout_buffer.close()
captured_stderr_buffer.close()

# Collect all PythonNodes that are products to pass values back to the main process.
python_nodes = tree_map(
lambda x: x if isinstance(x, PythonNode) else None, task.produces
)

return python_nodes, warning_reports, processed_exc_info
return (
python_nodes,
warning_reports,
processed_exc_info,
captured_stdout,
captured_stderr,
)


def _process_exception(
Expand Down
8 changes: 6 additions & 2 deletions src/pytask_parallel/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[Any]:
def _mock_processes_for_threads(
task: PTask, **kwargs: Any
) -> tuple[
None, list[Any], tuple[type[BaseException], BaseException, TracebackType] | None
None,
list[Any],
tuple[type[BaseException], BaseException, TracebackType] | None,
str,
str,
]:
"""Mock execution function such that it returns the same as for processes.

Expand All @@ -52,4 +56,4 @@ def _mock_processes_for_threads(
else:
handle_task_function_return(task, out)
exc_info = None
return None, [], exc_info
return None, [], exc_info, "", ""
8 changes: 5 additions & 3 deletions src/pytask_parallel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ def parse_future_result(
dict[str, PyTree[PythonNode | None]] | None,
list[WarningReport],
tuple[type[BaseException], BaseException, TracebackType] | None,
str,
str,
]:
"""Parse the result of a future."""
# An exception was raised before the task was executed.
future_exception = future.exception()
if future_exception is not None:
exc_info = _parse_future_exception(future_exception)
return None, [], exc_info
return None, [], exc_info, "", ""

out = future.result()
if isinstance(out, tuple) and len(out) == 3: # noqa: PLR2004
if isinstance(out, tuple) and len(out) == 5: # noqa: PLR2004
return out

if out is None:
return None, [], None
return None, [], None, "", ""

# What to do when the output does not match?
msg = (
Expand Down
58 changes: 58 additions & 0 deletions tests/test_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import textwrap

import pytest
from pytask import ExitCode
from pytask import cli
from pytask_parallel import ParallelBackend


@pytest.mark.end_to_end()
@pytest.mark.parametrize(
"parallel_backend", [ParallelBackend.PROCESSES, ParallelBackend.LOKY]
)
@pytest.mark.parametrize("show_capture", ["no", "stdout", "stderr", "all"])
def test_show_capture(tmp_path, runner, parallel_backend, show_capture):
source = """
import sys

def task_show_capture():
sys.stdout.write("xxxx")
sys.stderr.write("zzzz")
raise Exception
"""
tmp_path.joinpath("task_show_capture.py").write_text(textwrap.dedent(source))

cmd_arg = "-s" if show_capture == "s" else f"--show-capture={show_capture}"
result = runner.invoke(
cli,
[
tmp_path.as_posix(),
cmd_arg,
"--parallel-backend",
parallel_backend,
"-n",
"2",
],
)

assert result.exit_code == ExitCode.FAILED

if show_capture in ("no", "s"):
assert "Captured" not in result.output
elif show_capture == "stdout":
assert "Captured stdout" in result.output
assert "xxxx" in result.output
assert "Captured stderr" not in result.output
# assert "zzzz" not in result.output
elif show_capture == "stderr":
assert "Captured stdout" not in result.output
# assert "xxxx" not in result.output
assert "Captured stderr" in result.output
assert "zzzz" in result.output
elif show_capture == "all":
assert "Captured stdout" in result.output
assert "xxxx" in result.output
assert "Captured stderr" in result.output
assert "zzzz" in result.output
else: # pragma: no cover
raise NotImplementedError