Skip to content

Commit 9f7ec2e

Browse files
HonakerM and Kludex authored
Make UploadFile check for future rollover (#2962)
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 540ff5f commit 9f7ec2e

File tree

2 files changed

+82
-6
lines changed

2 files changed

+82
-6
lines changed

starlette/datastructures.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,10 @@ def __init__(
428428
self.size = size
429429
self.headers = headers or Headers()
430430

431+
# Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
432+
# Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
433+
self._max_mem_size = getattr(self.file, "_max_size", 0)
434+
431435
@property
432436
def content_type(self) -> str | None:
433437
return self.headers.get("content-type", None)
@@ -438,14 +442,24 @@ def _in_memory(self) -> bool:
438442
rolled_to_disk = getattr(self.file, "_rolled", True)
439443
return not rolled_to_disk
440444

445+
def _will_roll(self, size_to_add: int) -> bool:
446+
# If we're not in_memory then we will always roll
447+
if not self._in_memory:
448+
return True
449+
450+
# Check for SpooledTemporaryFile._max_size
451+
future_size = self.file.tell() + size_to_add
452+
return bool(future_size > self._max_mem_size) if self._max_mem_size else False
453+
441454
async def write(self, data: bytes) -> None:
455+
new_data_len = len(data)
442456
if self.size is not None:
443-
self.size += len(data)
457+
self.size += new_data_len
444458

445-
if self._in_memory:
446-
self.file.write(data)
447-
else:
459+
if self._will_roll(new_data_len):
448460
await run_in_threadpool(self.file.write, data)
461+
else:
462+
self.file.write(data)
449463

450464
async def read(self, size: int = -1) -> bytes:
451465
if self._in_memory:

tests/test_formparsers.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
from __future__ import annotations
22

33
import os
4+
import threading
5+
from collections.abc import Generator
46
from contextlib import AbstractContextManager, nullcontext as does_not_raise
7+
from io import BytesIO
58
from pathlib import Path
6-
from typing import Any
9+
from tempfile import SpooledTemporaryFile
10+
from typing import Any, ClassVar
11+
from unittest import mock
712

813
import pytest
914

1015
from starlette.applications import Starlette
1116
from starlette.datastructures import UploadFile
12-
from starlette.formparsers import MultiPartException, _user_safe_decode
17+
from starlette.formparsers import MultiPartException, MultiPartParser, _user_safe_decode
1318
from starlette.requests import Request
1419
from starlette.responses import JSONResponse
1520
from starlette.routing import Mount
@@ -104,6 +109,22 @@ async def app_read_body(scope: Scope, receive: Receive, send: Send) -> None:
104109
await response(scope, receive, send)
105110

106111

112+
async def app_monitor_thread(scope: Scope, receive: Receive, send: Send) -> None:
113+
"""Helper app to monitor what thread the app was called on.
114+
115+
This can later be used to validate thread/event loop operations.
116+
"""
117+
request = Request(scope, receive)
118+
119+
# Make sure we parse the form
120+
await request.form()
121+
await request.close()
122+
123+
# Send back the current thread id
124+
response = JSONResponse({"thread_ident": threading.current_thread().ident})
125+
await response(scope, receive, send)
126+
127+
107128
def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000, max_part_size: int = 1024 * 1024) -> ASGIApp:
108129
async def app(scope: Scope, receive: Receive, send: Send) -> None:
109130
request = Request(scope, receive)
@@ -303,6 +324,47 @@ def test_multipart_request_mixed_files_and_data(tmpdir: Path, test_client_factor
303324
}
304325

305326

327+
class ThreadTrackingSpooledTemporaryFile(SpooledTemporaryFile[bytes]):
328+
"""Helper class to track which threads performed the rollover operation.
329+
330+
This is not threadsafe/multi-test safe.
331+
"""
332+
333+
rollover_threads: ClassVar[set[int | None]] = set()
334+
335+
def rollover(self) -> None:
336+
ThreadTrackingSpooledTemporaryFile.rollover_threads.add(threading.current_thread().ident)
337+
super().rollover()
338+
339+
340+
@pytest.fixture
341+
def mock_spooled_temporary_file() -> Generator[None]:
342+
try:
343+
with mock.patch("starlette.formparsers.SpooledTemporaryFile", ThreadTrackingSpooledTemporaryFile):
344+
yield
345+
finally:
346+
ThreadTrackingSpooledTemporaryFile.rollover_threads.clear()
347+
348+
349+
def test_multipart_request_large_file_rollover_in_background_thread(
350+
mock_spooled_temporary_file: None, test_client_factory: TestClientFactory
351+
) -> None:
352+
"""Test that Spooled file rollovers happen in background threads."""
353+
data = BytesIO(b" " * (MultiPartParser.spool_max_size + 1))
354+
355+
client = test_client_factory(app_monitor_thread)
356+
response = client.post("/", files=[("test_large", data)])
357+
assert response.status_code == 200
358+
359+
# Parse the event thread id from the API response and ensure we have one
360+
app_thread_ident = response.json().get("thread_ident")
361+
assert app_thread_ident is not None
362+
363+
# Ensure the app thread was not the same as the rollover one and that a rollover thread exists
364+
assert app_thread_ident not in ThreadTrackingSpooledTemporaryFile.rollover_threads
365+
assert len(ThreadTrackingSpooledTemporaryFile.rollover_threads) == 1
366+
367+
306368
def test_multipart_request_with_charset_for_filename(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
307369
client = test_client_factory(app)
308370
response = client.post(

0 commit comments

Comments (0)