Skip to content

Commit cdb28a6

Browse files
committed
Implement workaround for twisted 24 or prior
1 parent 774c933 commit cdb28a6

File tree

1 file changed

+110
-19
lines changed

1 file changed

+110
-19
lines changed

src/_pytest/unittest.py

Lines changed: 110 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from collections.abc import Callable
77
from collections.abc import Generator
88
from collections.abc import Iterable
9+
from collections.abc import Iterator
10+
from enum import auto
11+
from enum import Enum
912
import inspect
1013
import sys
1114
import traceback
@@ -17,6 +20,7 @@
1720
from _pytest.compat import is_async_function
1821
from _pytest.config import hookimpl
1922
from _pytest.fixtures import FixtureRequest
23+
from _pytest.monkeypatch import MonkeyPatch
2024
from _pytest.nodes import Collector
2125
from _pytest.nodes import Item
2226
from _pytest.outcomes import exit
@@ -383,22 +387,91 @@ def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
383387
call.excinfo = call2.excinfo
384388

385389

390+
def _is_skipped(obj) -> bool:
391+
"""Return True if the given object has been marked with @unittest.skip."""
392+
return bool(getattr(obj, "__unittest_skip__", False))
393+
394+
386395
def pytest_configure() -> None:
387396
"""Register the TestCaseFunction class as an IReporter if twisted.trial is available."""
388-
if _is_twisted_trial_available():
397+
if _get_twisted_version() is not TwistedVersion.NotInstalled:
389398
from twisted.trial.itrial import IReporter
390399
from zope.interface import classImplements
391400

392401
classImplements(TestCaseFunction, IReporter)
393402

394403

395-
def _is_skipped(obj) -> bool:
396-
"""Return True if the given object has been marked with @unittest.skip."""
397-
return bool(getattr(obj, "__unittest_skip__", False))
404+
class TwistedVersion(Enum):
405+
"""
406+
The Twisted version installed in the environment.
407+
408+
We have different workarounds in place for different versions of Twisted.
409+
"""
410+
411+
# Twisted version 24 or prior.
412+
Version24 = auto()
413+
# Twisted version 25 or later.
414+
Version25 = auto()
415+
# Twisted version is not available.
416+
NotInstalled = auto()
398417

399418

400-
def _is_twisted_trial_available() -> bool:
401-
return "twisted.trial.unittest" in sys.modules
419+
def _get_twisted_version() -> TwistedVersion:
420+
# Check `"twisted.trial.unittest` is in sys.modules to ensure that Twisted is installed
421+
# being used to run the test suite.
422+
if "twisted.trial.unittest" not in sys.modules:
423+
return TwistedVersion.NotInstalled
424+
425+
import importlib.metadata
426+
427+
import packaging.version
428+
429+
version_str = importlib.metadata.version("twisted")
430+
version = packaging.version.parse(version_str)
431+
if version.major <= 24:
432+
return TwistedVersion.Version24
433+
else:
434+
return TwistedVersion.Version25
435+
436+
437+
# Name of the attribute in `twisted.python.Failure` instances that stores
438+
# the `sys.exc_info()` tuple.
439+
# See twisted.trial support in `pytest_runtest_protocol`.
440+
TWISTED_RAW_EXCINFO_ATTR = "_twisted_raw_excinfo"
441+
442+
443+
@hookimpl(wrapper=True)
444+
def pytest_runtest_protocol(item: Item) -> Iterator[None]:
445+
if _get_twisted_version() is TwistedVersion.Version24:
446+
import twisted.python.failure as ut
447+
448+
# Monkeypatch `Failure.__init__` to store the raw exception info.
449+
original__init__ = ut.Failure.__init__
450+
451+
def store_raw_exception_info(
452+
self, exc_value=None, exc_type=None, exc_tb=None, captureVars=None
453+
):
454+
if exc_value is None:
455+
raw_exc_info = sys.exc_info()
456+
else:
457+
if exc_type is None:
458+
exc_type = type(exc_value)
459+
if exc_tb is None:
460+
exc_tb = sys.exc_info()[2]
461+
raw_exc_info = (exc_type, exc_value, exc_tb)
462+
setattr(self, TWISTED_RAW_EXCINFO_ATTR, tuple(raw_exc_info))
463+
try:
464+
original__init__(
465+
self, exc_value, exc_type, exc_tb, captureVars=captureVars
466+
)
467+
except TypeError:
468+
original__init__(self, exc_value, exc_type, exc_tb)
469+
470+
with MonkeyPatch.context() as patcher:
471+
patcher.setattr(ut.Failure, "__init__", store_raw_exception_info)
472+
return (yield)
473+
else:
474+
return (yield)
402475

403476

404477
def _handle_twisted_exc_info(
@@ -409,16 +482,34 @@ def _handle_twisted_exc_info(
409482
Therefore, if `rawexcinfo` is a `Failure` instance, convert it into the equivalent `sys.exc_info()` tuple
410483
as expected by pytest.
411484
"""
412-
if isinstance(rawexcinfo, BaseException) and _is_twisted_trial_available():
413-
import twisted.python.failure
414-
415-
if isinstance(rawexcinfo, twisted.python.failure.Failure):
416-
tb = rawexcinfo.__traceback__
417-
if tb is None:
418-
tb = sys.exc_info()[2]
419-
return type(rawexcinfo.value), rawexcinfo.value, tb
420-
421-
# Unfortunately, because we cannot import `twisted.python.failure` at the top of the file
422-
# and use it in the signature, we need to use `type:ignore` here because we cannot narrow
423-
# the type properly in the `if` statement above.
424-
return rawexcinfo # type:ignore[return-value]
485+
twisted_version = _get_twisted_version()
486+
if twisted_version is TwistedVersion.NotInstalled:
487+
# Unfortunately, because we cannot import `twisted.python.failure` at the top of the file
488+
# and use it in the signature, we need to use `type:ignore` here because we cannot narrow
489+
# the type properly in the `if` statement above.
490+
return rawexcinfo # type:ignore[return-value]
491+
elif twisted_version is TwistedVersion.Version24:
492+
# Twisted calls addError() passing its own classes (like `twisted.python.Failure`), which violates
493+
# the `addError()` signature, so we extract the original `sys.exc_info()` tuple which is stored
494+
# in the object.
495+
if hasattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR):
496+
saved_exc_info = getattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
497+
# Delete the attribute from the original object to avoid leaks.
498+
delattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
499+
return saved_exc_info # type:ignore[no-any-return]
500+
return rawexcinfo # type:ignore[return-value]
501+
elif twisted_version is TwistedVersion.Version25:
502+
if isinstance(rawexcinfo, BaseException):
503+
import twisted.python.failure
504+
505+
if isinstance(rawexcinfo, twisted.python.failure.Failure):
506+
tb = rawexcinfo.__traceback__
507+
if tb is None:
508+
tb = sys.exc_info()[2]
509+
return type(rawexcinfo.value), rawexcinfo.value, tb
510+
511+
return rawexcinfo # type:ignore[return-value]
512+
else:
513+
# Ideally we would use assert_never() here, but it is not available in all Python versions
514+
# we support, plus we do not require `type_extensions` currently.
515+
assert False, f"Unexpected Twisted version: {twisted_version}"

0 commit comments

Comments
 (0)