|
| 1 | +import ast |
1 | 2 | import contextlib
|
| 3 | +from itertools import product |
2 | 4 | import os
|
3 | 5 | import sys
|
| 6 | +from tempfile import NamedTemporaryFile |
| 7 | +import time |
4 | 8 |
|
| 9 | +from _pytest.runner import CallInfo |
| 10 | +from _pytest.runner import TestReport |
5 | 11 | import pytest
|
| 12 | +from six import PY2 |
6 | 13 |
|
7 | 14 | from tests.utils import DummyTracer
|
8 | 15 | from tests.utils import TracerSpanContainer
|
@@ -89,3 +96,143 @@ def _snapshot(**kwargs):
|
89 | 96 | yield snapshot
|
90 | 97 |
|
91 | 98 | return _snapshot
|
| 99 | + |
| 100 | + |
| 101 | +# DEV: The dump_code_to_file function is adapted from the compile function in |
| 102 | +# the py_compile module of the Python standard library. It generates .pyc files |
| 103 | +# with the right format. |
| 104 | +if PY2: |
| 105 | + import marshal |
| 106 | + from py_compile import MAGIC |
| 107 | + from py_compile import wr_long |
| 108 | + |
| 109 | + from _pytest._code.code import ExceptionInfo |
| 110 | + |
| 111 | + def dump_code_to_file(code, file): |
| 112 | + file.write(MAGIC) |
| 113 | + wr_long(file, long(time.time())) # noqa |
| 114 | + marshal.dump(code, file) |
| 115 | + file.flush() |
| 116 | + |
| 117 | + |
| 118 | +else: |
| 119 | + import importlib |
| 120 | + |
| 121 | + code_to_pyc = getattr( |
| 122 | + importlib._bootstrap_external, "_code_to_bytecode" if sys.version_info < (3, 7) else "_code_to_timestamp_pyc" |
| 123 | + ) |
| 124 | + |
| 125 | + def dump_code_to_file(code, file): |
| 126 | + file.write(code_to_pyc(code, time.time(), len(code.co_code))) |
| 127 | + file.flush() |
| 128 | + |
| 129 | + |
| 130 | +def unwind_params(params): |
| 131 | + if params is None: |
| 132 | + yield None |
| 133 | + return |
| 134 | + |
| 135 | + for _ in product(*(((k, v) for v in vs) for k, vs in params.items())): |
| 136 | + yield dict(_) |
| 137 | + |
| 138 | + |
| 139 | +class FunctionDefFinder(ast.NodeVisitor): |
| 140 | + def __init__(self, func_name): |
| 141 | + super(FunctionDefFinder, self).__init__() |
| 142 | + self.func_name = func_name |
| 143 | + self._body = None |
| 144 | + |
| 145 | + def generic_visit(self, node): |
| 146 | + return self._body or super(FunctionDefFinder, self).generic_visit(node) |
| 147 | + |
| 148 | + def visit_FunctionDef(self, node): |
| 149 | + if node.name == self.func_name: |
| 150 | + self._body = node.body |
| 151 | + |
| 152 | + def find(self, file): |
| 153 | + with open(file) as f: |
| 154 | + t = ast.parse(f.read()) |
| 155 | + self.visit(t) |
| 156 | + t.body = self._body |
| 157 | + return t |
| 158 | + |
| 159 | + |
| 160 | +def run_function_from_file(item, params=None): |
| 161 | + file, _, func = item.location |
| 162 | + marker = item.get_closest_marker("subprocess") |
| 163 | + |
| 164 | + file_index = 1 |
| 165 | + args = marker.kwargs.get("args", []) |
| 166 | + args.insert(0, None) |
| 167 | + args.insert(0, sys.executable) |
| 168 | + if marker.kwargs.get("ddtrace_run", False): |
| 169 | + file_index += 1 |
| 170 | + args.insert(0, "ddtrace-run") |
| 171 | + |
| 172 | + env = os.environ.copy() |
| 173 | + env.update(marker.kwargs.get("env", {})) |
| 174 | + if params is not None: |
| 175 | + env.update(params) |
| 176 | + |
| 177 | + expected_status = marker.kwargs.get("status", 0) |
| 178 | + |
| 179 | + expected_out = marker.kwargs.get("out", "") |
| 180 | + if expected_out is not None: |
| 181 | + expected_out = expected_out.encode("utf-8") |
| 182 | + |
| 183 | + expected_err = marker.kwargs.get("err", "") |
| 184 | + if expected_err is not None: |
| 185 | + expected_err = expected_err.encode("utf-8") |
| 186 | + |
| 187 | + with NamedTemporaryFile(mode="wb", suffix=".pyc") as fp: |
| 188 | + dump_code_to_file(compile(FunctionDefFinder(func).find(file), file, "exec"), fp.file) |
| 189 | + |
| 190 | + start = time.time() |
| 191 | + args[file_index] = fp.name |
| 192 | + out, err, status, _ = call_program(*args, env=env) |
| 193 | + end = time.time() |
| 194 | + excinfo = None |
| 195 | + |
| 196 | + if status != expected_status: |
| 197 | + excinfo = AssertionError( |
| 198 | + "Expected status %s, got %s.\n=== Captured STDERR ===\n%s=== End of captured STDERR ===" |
| 199 | + % (expected_status, status, err.decode("utf-8")) |
| 200 | + ) |
| 201 | + elif expected_out is not None and out != expected_out: |
| 202 | + excinfo = AssertionError("STDOUT: Expected [%s] got [%s]" % (expected_out, out)) |
| 203 | + elif expected_err is not None and err != expected_err: |
| 204 | + excinfo = AssertionError("STDERR: Expected [%s] got [%s]" % (expected_err, err)) |
| 205 | + |
| 206 | + if PY2 and excinfo is not None: |
| 207 | + try: |
| 208 | + raise excinfo |
| 209 | + except Exception: |
| 210 | + excinfo = ExceptionInfo(sys.exc_info()) |
| 211 | + |
| 212 | + call_info_args = dict(result=None, excinfo=excinfo, start=start, stop=end, when="call") |
| 213 | + if not PY2: |
| 214 | + call_info_args["duration"] = end - start |
| 215 | + |
| 216 | + return TestReport.from_item_and_call(item, CallInfo(**call_info_args)) |
| 217 | + |
| 218 | + |
| 219 | +@pytest.hookimpl(tryfirst=True) |
| 220 | +def pytest_runtest_protocol(item): |
| 221 | + marker = item.get_closest_marker("subprocess") |
| 222 | + if marker: |
| 223 | + params = marker.kwargs.get("parametrize", None) |
| 224 | + ihook = item.ihook |
| 225 | + base_name = item.nodeid |
| 226 | + |
| 227 | + for ps in unwind_params(params): |
| 228 | + nodeid = (base_name + str(ps)) if ps is not None else base_name |
| 229 | + |
| 230 | + ihook.pytest_runtest_logstart(nodeid=nodeid, location=item.location) |
| 231 | + |
| 232 | + report = run_function_from_file(item, ps) |
| 233 | + report.nodeid = nodeid |
| 234 | + ihook.pytest_runtest_logreport(report=report) |
| 235 | + |
| 236 | + ihook.pytest_runtest_logfinish(nodeid=nodeid, location=item.location) |
| 237 | + |
| 238 | + return True |
0 commit comments