|
38 | 38 | SubprocessInfo = namedtuple("SubprocessInfo", ["process", "socket"]) |
39 | 39 |
|
40 | 40 |
|
| 41 | +def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT): |
| 42 | + """ |
| 43 | + Wait for expected signal(s) from a socket with proper timeout and EOF handling. |
| 44 | +
|
| 45 | + Args: |
| 46 | + sock: Connected socket to read from |
| 47 | + expected_signals: Single bytes object or list of bytes objects to wait for |
| 48 | + timeout: Socket timeout in seconds |
| 49 | +
|
| 50 | + Returns: |
| 51 | + bytes: Complete accumulated response buffer |
| 52 | +
|
| 53 | + Raises: |
| 54 | + RuntimeError: If connection closed before signal received or timeout |
| 55 | + """ |
| 56 | + if isinstance(expected_signals, bytes): |
| 57 | + expected_signals = [expected_signals] |
| 58 | + |
| 59 | + sock.settimeout(timeout) |
| 60 | + buffer = b"" |
| 61 | + |
| 62 | + while True: |
| 63 | + # Check if all expected signals are in buffer |
| 64 | + if all(sig in buffer for sig in expected_signals): |
| 65 | + return buffer |
| 66 | + |
| 67 | + try: |
| 68 | + chunk = sock.recv(4096) |
| 69 | + if not chunk: |
| 70 | + raise RuntimeError( |
| 71 | + f"Connection closed before receiving expected signals. " |
| 72 | + f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" |
| 73 | + ) |
| 74 | + buffer += chunk |
| 75 | + except socket.timeout: |
| 76 | + raise RuntimeError( |
| 77 | + f"Timeout waiting for signals. " |
| 78 | + f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" |
| 79 | + ) from None |
| 80 | + except OSError as e: |
| 81 | + raise RuntimeError( |
| 82 | + f"Socket error while waiting for signals: {e}. " |
| 83 | + f"Expected: {expected_signals}, Got: {buffer[-200:]!r}" |
| 84 | + ) from None |
| 85 | + |
| 86 | + |
| 87 | +def _cleanup_sockets(*sockets): |
| 88 | + """Safely close multiple sockets, ignoring errors.""" |
| 89 | + for sock in sockets: |
| 90 | + if sock is not None: |
| 91 | + try: |
| 92 | + sock.close() |
| 93 | + except OSError: |
| 94 | + pass |
| 95 | + |
| 96 | + |
| 97 | +def _cleanup_process(proc, timeout=SHORT_TIMEOUT): |
| 98 | + """Terminate a process gracefully, escalating to kill if needed.""" |
| 99 | + if proc.poll() is not None: |
| 100 | + return |
| 101 | + proc.terminate() |
| 102 | + try: |
| 103 | + proc.wait(timeout=timeout) |
| 104 | + return |
| 105 | + except subprocess.TimeoutExpired: |
| 106 | + pass |
| 107 | + proc.kill() |
| 108 | + try: |
| 109 | + proc.wait(timeout=timeout) |
| 110 | + except subprocess.TimeoutExpired: |
| 111 | + pass # Process refuses to die, nothing more we can do |
| 112 | + |
| 113 | + |
41 | 114 | @contextlib.contextmanager |
42 | | -def test_subprocess(script): |
| 115 | +def test_subprocess(script, wait_for_working=False): |
43 | 116 | """Context manager to create a test subprocess with socket synchronization. |
44 | 117 |
|
45 | 118 | Args: |
46 | | - script: Python code to execute in the subprocess |
| 119 | + script: Python code to execute in the subprocess. If wait_for_working |
| 120 | + is True, script should send b"working" after starting work. |
| 121 | + wait_for_working: If True, wait for both "ready" and "working" signals. |
| 122 | + Default False for backward compatibility. |
47 | 123 |
|
48 | 124 | Yields: |
49 | 125 | SubprocessInfo: Named tuple with process and socket objects |
@@ -80,19 +156,18 @@ def test_subprocess(script): |
80 | 156 | # Wait for process to connect and send ready signal |
81 | 157 | client_socket, _ = server_socket.accept() |
82 | 158 | server_socket.close() |
83 | | - response = client_socket.recv(1024) |
84 | | - if response != b"ready": |
85 | | - raise RuntimeError( |
86 | | - f"Unexpected response from subprocess: {response!r}" |
87 | | - ) |
| 159 | + server_socket = None |
| 160 | + |
| 161 | + # Wait for ready signal, and optionally working signal |
| 162 | + if wait_for_working: |
| 163 | + _wait_for_signal(client_socket, [b"ready", b"working"]) |
| 164 | + else: |
| 165 | + _wait_for_signal(client_socket, b"ready") |
88 | 166 |
|
89 | 167 | yield SubprocessInfo(proc, client_socket) |
90 | 168 | finally: |
91 | | - if client_socket is not None: |
92 | | - client_socket.close() |
93 | | - if proc.poll() is None: |
94 | | - proc.kill() |
95 | | - proc.wait() |
| 169 | + _cleanup_sockets(client_socket, server_socket) |
| 170 | + _cleanup_process(proc) |
96 | 171 |
|
97 | 172 |
|
98 | 173 | def close_and_unlink(file): |
|
0 commit comments