|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import os |
| 4 | +import threading |
| 5 | +from collections.abc import Generator |
4 | 6 | from contextlib import AbstractContextManager, nullcontext as does_not_raise |
| 7 | +from io import BytesIO |
5 | 8 | from pathlib import Path |
6 | | -from typing import Any |
| 9 | +from tempfile import SpooledTemporaryFile |
| 10 | +from typing import Any, ClassVar |
| 11 | +from unittest import mock |
7 | 12 |
|
8 | 13 | import pytest |
9 | 14 |
|
10 | 15 | from starlette.applications import Starlette |
11 | 16 | from starlette.datastructures import UploadFile |
12 | | -from starlette.formparsers import MultiPartException, _user_safe_decode |
| 17 | +from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode |
13 | 18 | from starlette.requests import Request |
14 | 19 | from starlette.responses import JSONResponse |
15 | 20 | from starlette.routing import Mount |
@@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None: |
104 | 109 | await response(scope, receive, send) |
105 | 110 |
|
106 | 111 |
|
| 112 | +async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None: |
| 113 | + """Helper app to monitor what thread the app was called on. |
| 114 | +
|
| 115 | + This can later be used to validate thread/event loop operations. |
| 116 | + """ |
| 117 | + request = Request(scope, receive) |
| 118 | + |
| 119 | + # Make sure we parse the form |
| 120 | + await request.form() |
| 121 | + await request.close() |
| 122 | + |
| 123 | + # Send back the current thread id |
| 124 | + response = JSONResponse({"thread_ident": threading.current_thread().ident}) |
| 125 | + await response(scope, receive, send) |
| 126 | + |
| 127 | + |
107 | 128 | def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp: |
108 | 129 | async def app(scope: Scope, receive: Receive, send: Send) -> None: |
109 | 130 | request = Request(scope, receive) |
@@ -303,6 +324,47 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor |
303 | 324 | } |
304 | 325 |
|
305 | 326 |
|
| 327 | +class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]): |
| 328 | + """Helper class to track which threads performed the rollover operation. |
| 329 | +
|
| 330 | + This is not threadsafe/multi-test safe. |
| 331 | + """ |
| 332 | + |
| 333 | + rollover_threads: ClassVar[set[int | None]] = set() |
| 334 | + |
| 335 | + def rollover(self) -> None: |
| 336 | + ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident) |
| 337 | + super().rollover() |
| 338 | + |
| 339 | + |
| 340 | +@pytest.fixture |
| 341 | +def mock_spooled_temporary_file() -> Generator[None]: |
| 342 | + try: |
| 343 | + with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile): |
| 344 | + yield |
| 345 | + finally: |
| 346 | + ThreadTrackingSpooledTemporaryFile.rollover_threads.clear() |
| 347 | + |
| 348 | + |
| 349 | +def test_multipart_request_large_file_rollover_in_background_thread( |
| 350 | + mock_spooled_temporary_file: None, test_client_factory: TestClientFactory |
| 351 | +) -> None: |
| 352 | + """Test that Spooled file rollovers happen in background threads.""" |
| 353 | + data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1)) |
| 354 | + |
| 355 | + client = test_client_factory(app_monitor_thread) |
| 356 | + response = client.post("/", files=[("test_large", data)]) |
| 357 | + assert response.status_code == 200 |
| 358 | + |
| 359 | + # Parse the event thread id from the API response and ensure we have one |
| 360 | + app_thread_ident = response.json().get("thread_ident") |
| 361 | + assert app_thread_ident is not None |
| 362 | + |
| 363 | + # Ensure the app thread was not the same as the rollover one and that a rollover thread exists |
| 364 | + assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads |
| 365 | + assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1 |
| 366 | + |
| 367 | + |
306 | 368 | def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None: |
307 | 369 | client = test_client_factory(app) |
308 | 370 | response = client.post( |
|
0 commit comments