Skip to content

Commit ba1ee52

Browse files
committed
Stream the final file to the cache
1 parent 17f3703 commit ba1ee52

File tree

3 files changed

+54
-25
lines changed

3 files changed

+54
-25
lines changed

src/pip/_internal/network/cache.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from __future__ import annotations
44

55
import os
6+
import shutil
67
from collections.abc import Generator
78
from contextlib import contextmanager
89
from datetime import datetime
9-
from typing import BinaryIO
10+
from typing import Any, BinaryIO, Callable
1011

1112
from pip._vendor.cachecontrol.cache import SeparateBodyBaseCache
1213
from pip._vendor.cachecontrol.caches import SeparateBodyFileCache
@@ -72,12 +73,13 @@ def get(self, key: str) -> bytes | None:
7273
with open(metadata_path, "rb") as f:
7374
return f.read()
7475

75-
def _write(self, path: str, data: bytes) -> None:
76+
def _write_to_file(self, path: str, writer_func: Callable[[BinaryIO], Any]) -> None:
77+
"""Common file writing logic with proper permissions and atomic replacement."""
7678
with suppressed_cache_errors():
7779
ensure_dir(os.path.dirname(path))
7880

7981
with adjacent_tmp_file(path) as f:
80-
f.write(data)
82+
writer_func(f)
8183
# Inherit the read/write permissions of the cache directory
8284
# to enable multi-user cache use-cases.
8385
mode = (
@@ -93,6 +95,12 @@ def _write(self, path: str, data: bytes) -> None:
9395

9496
replace(f.name, path)
9597

98+
def _write(self, path: str, data: bytes) -> None:
99+
self._write_to_file(path, lambda f: f.write(data))
100+
101+
def _write_from_io(self, path: str, source_file: BinaryIO) -> None:
102+
self._write_to_file(path, lambda f: shutil.copyfileobj(source_file, f))
103+
96104
def set(
97105
self, key: str, value: bytes, expires: int | datetime | None = None
98106
) -> None:
@@ -118,3 +126,8 @@ def get_body(self, key: str) -> BinaryIO | None:
118126
def set_body(self, key: str, body: bytes) -> None:
119127
path = self._get_cache_path(key) + ".body"
120128
self._write(path, body)
129+
130+
def set_body_from_io(self, key: str, body_file: BinaryIO) -> None:
131+
"""Set the body of the cache entry from a file object."""
132+
path = self._get_cache_path(key) + ".body"
133+
self._write_from_io(path, body_file)

src/pip/_internal/network/download.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pip._internal.exceptions import IncompleteDownloadError, NetworkConnectionError
2222
from pip._internal.models.index import PyPI
2323
from pip._internal.models.link import Link
24-
from pip._internal.network.cache import is_from_cache
24+
from pip._internal.network.cache import SafeFileCache, is_from_cache
2525
from pip._internal.network.session import CacheControlAdapter, PipSession
2626
from pip._internal.network.utils import HEADERS, raise_for_status, response_chunks
2727
from pip._internal.utils.misc import format_size, redact_auth_from_url, splitext
@@ -284,6 +284,14 @@ def _cache_resumed_download(
284284
)
285285
return
286286

287+
# Streaming the body relies on SafeFileCache.set_body_from_io; skip other caches
288+
if not isinstance(adapter.cache, SafeFileCache):
289+
logger.debug(
290+
"Skipping resume download caching: "
291+
"cache doesn't support separate body storage"
292+
)
293+
return
294+
287295
synthetic_request = PreparedRequest()
288296
synthetic_request.prepare(method="GET", url=url, headers={})
289297

@@ -300,15 +308,17 @@ def _cache_resumed_download(
300308
preload_content=False,
301309
)
302310

303-
# Use the cache controller to store this as a complete response
311+
# Store the metadata with an empty body, then stream the file body to the cache
312+
cache_url = adapter.controller.cache_url(url)
313+
adapter.cache.set(
314+
cache_url,
315+
adapter.controller.serializer.dumps(
316+
synthetic_request, synthetic_response, b""
317+
),
318+
)
304319
download.output_file.flush()
305320
with open(download.output_file.name, "rb") as f:
306-
adapter.controller.cache_response(
307-
synthetic_request,
308-
synthetic_response,
309-
body=f.read(),
310-
status_codes=(200, 203, 300, 301, 308),
311-
)
321+
adapter.cache.set_body_from_io(cache_url, f)
312322

313323
logger.debug(
314324
"Cached resumed download as complete response for future use: %s", url

tests/unit/test_network_download.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from pip._internal.exceptions import IncompleteDownloadError
1111
from pip._internal.models.link import Link
12+
from pip._internal.network.cache import SafeFileCache
1213
from pip._internal.network.download import (
1314
Downloader,
1415
_get_http_response_size,
@@ -374,6 +375,13 @@ def test_resumed_download_caching(tmpdir: Path) -> None:
374375
mock_adapter = MagicMock(spec=CacheControlAdapter)
375376
mock_controller = MagicMock()
376377
mock_adapter.controller = mock_controller
378+
mock_controller.cache_url = MagicMock(return_value="cache_key")
379+
mock_controller.serializer = MagicMock()
380+
mock_controller.serializer.dumps = MagicMock(return_value=b"serialized_data")
381+
382+
# Mock the cache to be a SafeFileCache
383+
mock_cache = MagicMock(spec=SafeFileCache)
384+
mock_adapter.cache = mock_cache
377385

378386
# Create a mock for the session adapters
379387
adapters_mock = MagicMock()
@@ -392,20 +400,18 @@ def test_resumed_download_caching(tmpdir: Path) -> None:
392400
expected_bytes = b"0cfa7e9d-1868-4dd7-9fb3-f2561d5dfd89"
393401
assert downloaded_bytes == expected_bytes
394402

395-
# Verify that cache_response was called for the resumed download
396-
mock_controller.cache_response.assert_called_once()
403+
# Verify that cache.set was called for metadata
404+
mock_cache.set.assert_called_once()
405+
406+
# Verify that set_body_from_io was called for streaming the body
407+
mock_cache.set_body_from_io.assert_called_once()
397408

398-
# Get the call arguments to verify the cached content
399-
call_args = mock_controller.cache_response.call_args
400-
assert call_args is not None
409+
# Verify the call arguments
410+
set_call_args = mock_cache.set.call_args
411+
assert set_call_args[0][0] == "cache_key" # First argument should be cache_key
401412

402-
# Extract positional and keyword arguments
403-
args, kwargs = call_args
404-
request, response = args
405-
body = kwargs.get("body")
406-
status_codes = kwargs.get("status_codes")
413+
set_body_call_args = mock_cache.set_body_from_io.call_args
407414

408-
assert body == expected_bytes, "Cached body should match complete file content"
409-
assert response.status == 200, "Cached response should have status 200"
410-
assert request.url == link.url_without_fragment
411-
assert 200 in status_codes
415+
assert set_body_call_args[0][0] == "cache_key"
416+
assert hasattr(set_body_call_args[0][1], "read")
417+
assert set_body_call_args[0][1].name == filepath

0 commit comments

Comments
 (0)