diff --git a/src/_pytest/capture.py b/src/_pytest/capture.py index 0cd3ce60427..63bbea46556 100644 --- a/src/_pytest/capture.py +++ b/src/_pytest/capture.py @@ -2,6 +2,7 @@ per-test stdout/stderr capturing mechanism. """ +import atexit import collections import contextlib import io @@ -9,6 +10,9 @@ import sys from io import UnsupportedOperation from tempfile import TemporaryFile +from typing import Callable +from typing import List +from typing import Optional import pytest from _pytest.compat import CaptureIO @@ -77,19 +81,29 @@ def __init__(self, method): self._method = method self._global_capturing = None self._current_item = None + self._atexit_funcs = [] # type: List[Callable] + atexit.register(self._atexit_run) + self._tmpfiles = {} def __repr__(self): return "".format( self._method, self._global_capturing, self._current_item ) + def _atexit_register(self, func): + self._atexit_funcs.append(func) + + def _atexit_run(self): + for func in self._atexit_funcs: + func() + def _getcapture(self, method): if method == "fd": - return MultiCapture(out=True, err=True, Capture=FDCapture) + return MultiCapture(out=True, err=True, Capture=FDCapture, capman=self) elif method == "sys": - return MultiCapture(out=True, err=True, Capture=SysCapture) + return MultiCapture(out=True, err=True, Capture=SysCapture, capman=self) elif method == "no": - return MultiCapture(out=False, err=False, in_=False) + return MultiCapture(out=False, err=False, in_=False, capman=self) raise ValueError("unknown capturing method: %r" % method) # pragma: no cover def is_capturing(self): @@ -450,13 +464,20 @@ class MultiCapture: out = err = in_ = None _state = None - def __init__(self, out=True, err=True, in_=True, Capture=None): + def __init__( + self, + out=True, + err=True, + in_=True, + Capture=None, + capman: Optional[CaptureManager] = None, + ): if in_: - self.in_ = Capture(0) + self.in_ = Capture(0, capman=capman) if out: - self.out = Capture(1) + self.out = Capture(1, capman=capman) if err: - self.err = Capture(2) + self.err = Capture(2, capman=capman) def __repr__(self): return "".format( @@ -539,8 +560,9 @@ class FDCaptureBinary: EMPTY_BUFFER = b"" _state = None - def __init__(self, targetfd, tmpfile=None): + def __init__(self, targetfd, tmpfile=None, capman: Optional[CaptureManager] = None): self.targetfd = targetfd + self._capman = capman try: self.targetfd_save = os.dup(self.targetfd) except OSError: @@ -551,15 +573,33 @@ def __init__(self, targetfd, tmpfile=None): self.done = self._done if targetfd == 0: assert not tmpfile, "cannot set tmpfile with stdin" - tmpfile = open(os.devnull, "r") + if capman: + try: + tmpfile = capman._tmpfiles[0] + assert not tmpfile.closed + except KeyError: + tmpfile = open(os.devnull, "r") + capman._tmpfiles[targetfd] = tmpfile + else: + tmpfile = open(os.devnull, "r") self.syscapture = SysCapture(targetfd) else: if tmpfile is None: - f = TemporaryFile() - with f: - tmpfile = safe_text_dupfile(f, mode="wb+") + if capman: + try: + tmpfile = capman._tmpfiles[targetfd] + assert not tmpfile.closed + except KeyError: + f = TemporaryFile() + with f: + tmpfile = safe_text_dupfile(f, mode="wb+") + capman._tmpfiles[targetfd] = tmpfile + else: + f = TemporaryFile() + with f: + tmpfile = safe_text_dupfile(f, mode="wb+") if targetfd in patchsysdict: - self.syscapture = SysCapture(targetfd, tmpfile) + self.syscapture = SysCapture(targetfd, tmpfile, capman) else: self.syscapture = NoCapture() self.tmpfile = tmpfile @@ -594,7 +634,12 @@ def _done(self): os.dup2(targetfd_save, self.targetfd) os.close(targetfd_save) self.syscapture.done() - self.tmpfile.close() + if self._capman: + # Redirect any remaining output. + os.dup2(self.targetfd, self.tmpfile_fd) + self._capman._atexit_register(self.tmpfile.close) + else: + self.tmpfile.close() self._state = "done" def suspend(self): @@ -636,8 +681,9 @@ class SysCapture: EMPTY_BUFFER = str() _state = None - def __init__(self, fd, tmpfile=None): + def __init__(self, fd, tmpfile=None, capman: Optional[CaptureManager] = None): name = patchsysdict[fd] + self._capman = capman self._old = getattr(sys, name) self.name = name if tmpfile is None: @@ -665,7 +711,10 @@ def snap(self): def done(self): setattr(sys, self.name, self._old) del self._old - self.tmpfile.close() + if self._capman: + self._capman._atexit_register(self.tmpfile.close) + else: + self.tmpfile.close() self._state = "done" def suspend(self): diff --git a/testing/test_capture.py b/testing/test_capture.py index 94af3aef75c..c6837f96c3a 100644 --- a/testing/test_capture.py +++ b/testing/test_capture.py @@ -1490,3 +1490,47 @@ def test_fails(): result_with_capture.stdout.fnmatch_lines( ["E * TypeError: write() argument must be str, not bytes"] ) + + +def test_logging_in_atexit(testdir): + p = testdir.makepyfile( + """ + import atexit + import logging + import sys + + cur_stdout = sys.stdout + LOGGER = logging.getLogger(__name__) + + def test_fail(): + assert 0 + + def _atexit(): + print("test-print in atexit", cur_stdout) + LOGGER.error("test-log in atexit") + + print() + print("test-register") + print() + atexit.register(_atexit) + logging.basicConfig() + + LOGGER.error("log_setup_not_shown_from_collection") + + print(sys.stderr, id(sys.stderr)) + """ + ) + result = testdir.runpytest_subprocess(str(p)) + result.stdout.fnmatch_lines( + [ + "*= 1 failed in *", + "test-print in atexit <_pytest.capture.EncodedFile object *", + ] + ) + assert result.stderr.lines == ["ERROR:test_logging_in_atexit:test-log in atexit"] + assert result.ret == 1 + + output = str(result.stdout) + str(result.stderr) + assert "test-register" not in output + assert "*- Captured stderr call -*" not in output + assert "log_setup_not_shown_from_collection" not in output