Skip to content

Commit 42fc895

Browse files
committed
multithreaded download
1 parent 4c0170e commit 42fc895

File tree

1 file changed

+135
-22
lines changed

1 file changed

+135
-22
lines changed

gguf-py/gguf/utility.py

Lines changed: 135 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4-
from typing import Literal
4+
from typing import Literal, Any
55

66
import os
77
import json
8+
import requests
9+
import threading
10+
from urllib.parse import urlparse
811

912

1013
def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -110,6 +113,10 @@ class SafetensorRemote:
110113
BASE_DOMAIN = "https://huggingface.co"
111114
ALIGNMENT = 8 # bytes
112115

116+
# start using multithread download for files larger than 100MB
117+
MULTITHREAD_THREDSHOLD = 100 * 1024 * 1024
118+
MULTITHREAD_COUNT = 8 # number of threads
119+
113120
@classmethod
114121
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
115122
"""
@@ -211,47 +218,153 @@ def get_metadata(cls, url: str) -> tuple[dict, int]:
211218
except json.JSONDecodeError as e:
212219
raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
213220

221+
@classmethod
222+
def _get_request_headers(cls) -> dict[str, str]:
223+
"""Prepare common headers for requests."""
224+
headers = {"User-Agent": "convert_hf_to_gguf"}
225+
if os.environ.get("HF_TOKEN"):
226+
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
227+
return headers
228+
214229
@classmethod
def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
    """
    Get raw byte data from a remote file by HTTP range request.

    If size is -1, reads from 'start' to the end of the file (single request).
    If size >= MULTITHREAD_THREDSHOLD, the range is split evenly across
    MULTITHREAD_COUNT parallel requests and reassembled in order.
    Otherwise a single ranged request is used.

    Raises ValueError for a malformed URL, and IOError when the number of
    bytes received does not match the number requested.
    """
    parsed_url = urlparse(url)
    if not parsed_url.scheme or not parsed_url.netloc:
        raise ValueError(f"Invalid URL: {url}")

    # Nothing to fetch. This also avoids emitting an invalid descending
    # "bytes=start-(start-1)" Range header in the single-request path below.
    if size == 0:
        return b""

    common_headers = cls._get_request_headers()

    # --- Multithreaded path ---
    if size >= cls.MULTITHREAD_THREDSHOLD and cls.MULTITHREAD_COUNT > 1:
        num_threads = cls.MULTITHREAD_COUNT
        # Each slot holds either the chunk's bytes on success or the raised exception
        results: list[Any] = [None] * num_threads
        threads: list[threading.Thread] = []

        def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict) -> None:
            """Worker: download one byte range into result_list[index]."""
            thread_headers = headers.copy()
            # HTTP Range header uses an inclusive end byte
            range_end = chunk_start + chunk_size - 1
            thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
            try:
                # stream=False makes requests buffer the full body before returning
                response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)
                response.raise_for_status()

                content = response.content
                if len(content) != chunk_size:
                    # A short/long chunk would silently corrupt the reassembled data
                    raise IOError(
                        f"Thread {index}: Downloaded chunk size mismatch for range {thread_headers['Range']}. "
                        f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
                    )
                result_list[index] = content
            except Exception as e:
                # Store the exception; the main thread re-raises it after join()
                result_list[index] = e

        # Split 'size' as evenly as possible; the first 'remainder' chunks get one extra byte
        base_chunk_size = size // num_threads
        remainder = size % num_threads
        current_offset = start

        for i in range(num_threads):
            chunk_size = base_chunk_size + (1 if i < remainder else 0)
            if chunk_size == 0:
                # Defensive only: cannot happen while size >= MULTITHREAD_THREDSHOLD >> num_threads
                results[i] = b""
                continue

            thread = threading.Thread(
                target=download_chunk,
                args=(url, current_offset, chunk_size, i, results, common_headers),
                daemon=True,  # do not block interpreter exit if a worker hangs
            )
            threads.append(thread)
            thread.start()
            current_offset += chunk_size  # next chunk starts where this one ends

        # Wait for all workers to complete
        for thread in threads:
            thread.join()

        # Check results for errors and concatenate chunks in request order
        final_data_parts = []
        for i in range(num_threads):
            result = results[i]
            if isinstance(result, Exception):
                # Surface the first failure to the caller
                raise result
            elif result is None:
                # A worker exited without storing data or an exception (unexpected)
                expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
                if expected_chunk_size > 0:
                    raise RuntimeError(f"Thread {i} finished without providing data or exception for a non-zero chunk.")
                final_data_parts.append(b"")
            else:
                final_data_parts.append(result)

        final_data = b"".join(final_data_parts)

        # Final validation: the assembled size must match the requested size
        if len(final_data) != size:
            raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start + size - 1}")

        return final_data

    # --- Single-request path ---
    else:
        headers = common_headers.copy()
        if size > -1:
            # Range header uses an inclusive end byte
            range_end = start + size - 1
            headers["Range"] = f"bytes={start}-{range_end}"
        elif start > 0:
            # Open-ended range: from 'start' to the end of the file
            headers["Range"] = f"bytes={start}-"
        # start == 0 and size == -1: no Range header, fetch the whole file

        response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)
        response.raise_for_status()
        content = response.content

        # Validate downloaded size when an exact byte count was requested
        # (206 Partial Content is the expected status for a honored Range)
        if size > -1 and len(content) != size:
            status_code = response.status_code
            content_range = response.headers.get('Content-Range')
            raise IOError(
                f"Single thread downloaded size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
                f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
            )

        return content
237354

238355
@classmethod
239356
def check_file_exist(cls, url: str) -> bool:
240357
"""
241358
Check if a file exists at the given URL.
242359
Returns True if the file exists, False otherwise.
243360
"""
244-
import requests
245-
from urllib.parse import urlparse
246-
247361
parsed_url = urlparse(url)
248362
if not parsed_url.scheme or not parsed_url.netloc:
249363
raise ValueError(f"Invalid URL: {url}")
250364

251365
try:
252-
headers = {"Range": "bytes=0-0"}
253-
if os.environ.get("HF_TOKEN"):
254-
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
366+
headers = cls._get_request_headers()
367+
headers["Range"] = "bytes=0-0" # Request a small range to check existence
255368
response = requests.head(url, allow_redirects=True, headers=headers)
256369
# Success (2xx) or redirect (3xx)
257370
return 200 <= response.status_code < 400

0 commit comments

Comments
 (0)