@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Literal
+from typing import Literal, Any
 
 import os
 import json
+import requests
+import threading
+from urllib.parse import urlparse
 
 
 def fill_templated_filename(filename: str, output_type: str | None) -> str:
@@ -110,6 +113,10 @@ class SafetensorRemote:
     BASE_DOMAIN = "https://huggingface.co"
     ALIGNMENT = 8 # bytes
 
+    # use multithreaded download for files larger than 100 MB
+    MULTITHREAD_THRESHOLD = 100 * 1024 * 1024
+    MULTITHREAD_COUNT = 8  # number of threads
+
     @classmethod
     def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
         """
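What these two constants imply in practice: reads smaller than MULTITHREAD_THRESHOLD (100 MiB) stay on the single-request path, while larger reads are split across MULTITHREAD_COUNT range requests, with any remainder spread one byte at a time over the first threads. A standalone sketch of that chunk-split arithmetic (it mirrors the loop added in the next hunk; the 200 MiB size is illustrative, not from the patch):

size = 200 * 1024 * 1024                 # hypothetical read well above the threshold
num_threads = 8                          # MULTITHREAD_COUNT
base, remainder = divmod(size, num_threads)
chunks = [base + (1 if i < remainder else 0) for i in range(num_threads)]
assert sum(chunks) == size               # the chunks always reassemble exactly
print(chunks[0])                         # 26214400 bytes (25 MiB) per thread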
@@ -211,47 +218,153 @@ def get_metadata(cls, url: str) -> tuple[dict, int]:
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
 
+    @classmethod
+    def _get_request_headers(cls) -> dict[str, str]:
+        """Prepare common headers for requests."""
+        headers = {"User-Agent": "convert_hf_to_gguf"}
+        if os.environ.get("HF_TOKEN"):
+            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+        return headers
+
     @classmethod
     def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
         """
-        Get raw byte data from a remote file by range.
-        If size is not specified, it will read the entire file.
-        """
-        import requests
-        from urllib.parse import urlparse
+        Get raw byte data from a remote file by range using single or multi-threaded download.
 
+        If size is -1, it attempts to read from 'start' to the end of the file (single-threaded only).
+        If size is >= MULTITHREAD_THRESHOLD, it uses multiple threads.
+        Otherwise, it uses a single request.
+        """
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
-        headers = {}
-        if os.environ.get("HF_TOKEN"):
-            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
-        if size > -1:
-            headers["Range"] = f"bytes={start}-{start + size}"
-        response = requests.get(url, allow_redirects=True, headers=headers)
-        response.raise_for_status()
-
-        # Get raw byte data
-        return response.content[:size]
+        common_headers = cls._get_request_headers()
+
+        # --- Multithreaded path ---
+        if size >= cls.MULTITHREAD_THRESHOLD and cls.MULTITHREAD_COUNT > 1:
+            num_threads = cls.MULTITHREAD_COUNT
+            results: list[Any] = [None] * num_threads  # store results or exceptions
+            threads: list[threading.Thread] = []
+
+            def download_chunk(chunk_url: str, chunk_start: int, chunk_size: int, index: int, result_list: list, headers: dict):
+                """Worker function for one download thread."""
+                thread_headers = headers.copy()
+                # the Range header's end byte is inclusive
+                range_end = chunk_start + chunk_size - 1
+                thread_headers["Range"] = f"bytes={chunk_start}-{range_end}"
+                try:
+                    # stream=False makes requests download the whole body before returning
+                    response = requests.get(chunk_url, allow_redirects=True, headers=thread_headers, stream=False, timeout=120)
+                    response.raise_for_status()  # check for HTTP errors
+
+                    content = response.content
+                    if len(content) != chunk_size:
+                        # critical check: a truncated chunk would silently corrupt the assembled data
+                        raise IOError(
+                            f"Thread {index}: downloaded chunk size mismatch for range {thread_headers['Range']}. "
+                            f"Expected {chunk_size}, got {len(content)}. Status: {response.status_code}. URL: {chunk_url}"
+                        )
+                    result_list[index] = content
+                except Exception as e:
+                    # store the exception so the main thread can re-raise it
+                    result_list[index] = e
+
+            # calculate chunk sizes and create/start the threads
+            base_chunk_size = size // num_threads
+            remainder = size % num_threads
+            current_offset = start
+
+            for i in range(num_threads):
+                chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                if chunk_size == 0:  # cannot happen while size >= threshold, but handle defensively
+                    results[i] = b""  # store empty bytes for this "chunk"
+                    continue
+
+                thread = threading.Thread(
+                    target=download_chunk,
+                    args=(url, current_offset, chunk_size, i, results, common_headers),
+                    daemon=True  # do not block interpreter exit on a stuck worker (the join below normally prevents this)
+                )
+                threads.append(thread)
+                thread.start()
+                current_offset += chunk_size  # move the offset to the next chunk
+
+            # wait for all threads to complete
+            for thread in threads:
+                thread.join()  # wait indefinitely for each thread
+
+            # check the results for errors and concatenate the chunks
+            final_data_parts = []
+            for i in range(num_threads):
+                result = results[i]
+                if isinstance(result, Exception):
+                    # raise the first exception encountered
+                    raise result
+                elif result is None:
+                    # a thread finished without setting its result or an exception (unexpected)
+                    expected_chunk_size = base_chunk_size + (1 if i < remainder else 0)
+                    if expected_chunk_size > 0:
+                        raise RuntimeError(f"Thread {i} finished without providing data or an exception for a non-zero chunk.")
+                    else:
+                        final_data_parts.append(b"")  # empty bytes for a zero-size chunk
+                else:
+                    final_data_parts.append(result)
+
+            # combine the byte chunks
+            final_data = b"".join(final_data_parts)
+
+            # final validation: does the combined size match the requested size?
+            if len(final_data) != size:
+                raise IOError(f"Final assembled data size mismatch. Expected {size}, got {len(final_data)}. URL: {url}, Range: {start}-{start + size - 1}")
+
+            return final_data
+
+        # --- Single-threaded path ---
+        else:
+            headers = common_headers.copy()
+            if size > -1:
+                # the Range header's end byte is inclusive
+                range_end = start + size - 1
+                headers["Range"] = f"bytes={start}-{range_end}"
+            elif start > 0:
+                # request from the start offset to the end of the file
+                headers["Range"] = f"bytes={start}-"
+            # if start == 0 and size == -1, no Range header is needed (fetch the whole file)
+
+            response = requests.get(url, allow_redirects=True, headers=headers, stream=False, timeout=120)
+            response.raise_for_status()
+            content = response.content
+
+            # validate the downloaded size if a specific size was requested
+            if size > -1 and len(content) != size:
+                # note: 206 Partial Content is the expected status for a successful range request
+                status_code = response.status_code
+                content_range = response.headers.get('Content-Range')
+                raise IOError(
+                    f"Single-threaded download size mismatch. Requested {size} bytes from offset {start} (Range: {headers.get('Range')}), "
+                    f"got {len(content)} bytes. Status: {status_code}, Content-Range: {content_range}. URL: {url}"
+                )
+
+            return content
 
     @classmethod
     def check_file_exist(cls, url: str) -> bool:
         """
         Check if a file exists at the given URL.
         Returns True if the file exists, False otherwise.
         """
-        import requests
-        from urllib.parse import urlparse
-
         parsed_url = urlparse(url)
         if not parsed_url.scheme or not parsed_url.netloc:
             raise ValueError(f"Invalid URL: {url}")
 
         try:
-            headers = {"Range": "bytes=0-0"}
-            if os.environ.get("HF_TOKEN"):
-                headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
+            headers = cls._get_request_headers()
+            headers["Range"] = "bytes=0-0"  # request a single byte to check existence
             response = requests.head(url, allow_redirects=True, headers=headers)
             # Success (2xx) or redirect (3xx)
             return 200 <= response.status_code < 400
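The correctness detail both new paths rely on: the end offset in an HTTP Range header is inclusive (RFC 9110), so bytes=0-7 returns 8 bytes. The removed single-threaded code asked for bytes={start}-{start + size}, that is, size + 1 bytes, and masked it with the [:size] slice; the replacement computes the inclusive end explicitly. A two-line illustration with made-up offsets:

start, size = 100, 8
print(f"bytes={start}-{start + size - 1}")  # bytes=100-107 -> exactly 8 bytes (new code)
print(f"bytes={start}-{start + size}")      # bytes=100-108 -> 9 bytes (removed code)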
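For context, a minimal usage sketch of the patched class (assuming it lives in gguf-py's gguf.utility module; the URL is made up, network access is required, and HF_TOKEN is picked up from the environment for gated repos):

from gguf.utility import SafetensorRemote

url = "https://huggingface.co/some-org/some-model/resolve/main/model.safetensors"
if SafetensorRemote.check_file_exist(url):
    # small reads use one range request; reads of MULTITHREAD_THRESHOLD
    # bytes or more fan out over MULTITHREAD_COUNT parallel range requests
    first_bytes = SafetensorRemote.get_data_by_range(url, start=0, size=8)
    assert len(first_bytes) == 8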