Skip to content

Commit 64eda5d

Browse files
ngxsoncompilade
andauthored
convert : ability to lazy-load safetensors remotely without downloading to disk (#12820)
* gguf util : add SafetensorRemote * fix style * convert: add --remote option * convert : allow using lazy remote tensors It's a bit slow for now since everything is blocking and single-threaded. * correct metadata.name * small style fix * support HF_TOKEN * convert : use writeable buffer for remote lazy tensors * convert : fix flake8 lint regarding lamdba assigment * multithreaded download * multithread: print debug * fix style * Revert "multithreaded download" This reverts commit 42fc895. * bring back _get_request_headers --------- Co-authored-by: Francis Couture-Harpin <[email protected]>
1 parent fe5b78c commit 64eda5d

File tree

2 files changed

+244
-7
lines changed

2 files changed

+244
-7
lines changed

convert_hf_to_gguf.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class Model:
6565
model_name: str | None
6666
metadata_override: Path | None
6767
dir_model_card: Path
68+
remote_hf_model_id: str | None
6869

6970
# subclasses should define this!
7071
model_arch: gguf.MODEL_ARCH
@@ -73,7 +74,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
7374
use_temp_file: bool = False, eager: bool = False,
7475
metadata_override: Path | None = None, model_name: str | None = None,
7576
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
76-
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
77+
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
7778
if type(self) is Model:
7879
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
7980

@@ -83,11 +84,24 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
8384
self.is_big_endian = is_big_endian
8485
self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
8586
self.use_temp_file = use_temp_file
86-
self.lazy = not eager
87-
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
88-
self.is_safetensors = len(self.part_names) > 0
89-
if not self.is_safetensors:
90-
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
87+
self.lazy = not eager or (remote_hf_model_id is not None)
88+
self.remote_hf_model_id = remote_hf_model_id
89+
if remote_hf_model_id is not None:
90+
self.is_safetensors = True
91+
92+
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
93+
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
94+
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
95+
self.tensor_names = set(name for name in remote_tensors.keys())
96+
for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items():
97+
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
98+
99+
self.get_tensors = get_remote_tensors
100+
else:
101+
self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
102+
self.is_safetensors = len(self.part_names) > 0
103+
if not self.is_safetensors:
104+
self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
91105
self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
92106
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
93107
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -393,6 +407,10 @@ def prepare_metadata(self, vocab_only: bool):
393407

394408
self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params)
395409

410+
# If we are using HF model id, set the metadata name to the model id
411+
if self.remote_hf_model_id:
412+
self.metadata.name = self.remote_hf_model_id
413+
396414
# Fallback to model directory name if metadata name is still missing
397415
if self.metadata.name is None:
398416
self.metadata.name = self.dir_model.name
@@ -5403,6 +5421,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
54035421
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
54045422
return cast(torch.Tensor, lazy)
54055423

5424+
@classmethod
5425+
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
5426+
dtype = cls._dtype_str_map[remote_tensor.dtype]
5427+
shape = remote_tensor.shape
5428+
meta = cls.meta_with_dtype_and_shape(dtype, shape)
5429+
lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
5430+
return cast(torch.Tensor, lazy)
5431+
54065432
@classmethod
54075433
def __torch_function__(cls, func, types, args=(), kwargs=None):
54085434
del types # unused
@@ -5480,6 +5506,10 @@ def parse_args() -> argparse.Namespace:
54805506
"--print-supported-models", action="store_true",
54815507
help="Print the supported models"
54825508
)
5509+
parser.add_argument(
5510+
"--remote", action="store_true",
5511+
help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
5512+
)
54835513

54845514
args = parser.parse_args()
54855515
if not args.print_supported_models and args.model is None:
@@ -5520,6 +5550,14 @@ def main() -> None:
55205550

55215551
dir_model = args.model
55225552

5553+
if args.remote:
5554+
from huggingface_hub import snapshot_download
5555+
local_dir = snapshot_download(
5556+
repo_id=str(dir_model),
5557+
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
5558+
dir_model = Path(local_dir)
5559+
logger.info(f"Downloaded config and tokenizer to {local_dir}")
5560+
55235561
if not dir_model.is_dir():
55245562
logger.error(f'Error: {args.model} is not a directory')
55255563
sys.exit(1)
@@ -5541,6 +5579,9 @@ def main() -> None:
55415579

55425580
if args.outfile is not None:
55435581
fname_out = args.outfile
5582+
elif args.remote:
5583+
# if remote, use the model ID as the output file name
5584+
fname_out = Path("./" + str(args.model).replace("/", "-") + "-{ftype}.gguf")
55445585
else:
55455586
fname_out = dir_model
55465587

@@ -5564,7 +5605,8 @@ def main() -> None:
55645605
metadata_override=args.metadata, model_name=args.model_name,
55655606
split_max_tensors=args.split_max_tensors,
55665607
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
5567-
small_first_shard=args.no_tensor_first_split)
5608+
small_first_shard=args.no_tensor_first_split,
5609+
remote_hf_model_id=str(args.model) if args.remote else None)
55685610

55695611
if args.vocab_only:
55705612
logger.info("Exporting model vocab...")

gguf-py/gguf/utility.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from typing import Literal
45

6+
import os
7+
import json
8+
59

610
def fill_templated_filename(filename: str, output_type: str | None) -> str:
711
# Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
@@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
6771
kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
6872

6973
return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
74+
75+
76+
@dataclass
77+
class RemoteTensor:
78+
dtype: str
79+
shape: tuple[int, ...]
80+
offset_start: int
81+
size: int
82+
url: str
83+
84+
def data(self) -> bytearray:
85+
# TODO: handle request errors (maybe with limited retries?)
86+
# NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable
87+
data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size))
88+
return data
89+
90+
91+
class SafetensorRemote:
92+
"""
93+
Uility class to handle remote safetensor files.
94+
This class is designed to work with Hugging Face model repositories.
95+
96+
Example (one model has single safetensor file, the other has multiple):
97+
for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
98+
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
99+
print(tensors)
100+
101+
Example reading tensor data:
102+
tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
103+
for name, meta in tensors.items():
104+
dtype, shape, offset_start, size, remote_safetensor_url = meta
105+
# read the tensor data
106+
data = SafetensorRemote.get_data_by_range(remote_safetensor_url, offset_start, size)
107+
print(data)
108+
"""
109+
110+
BASE_DOMAIN = "https://huggingface.co"
111+
ALIGNMENT = 8 # bytes
112+
113+
@classmethod
114+
def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
115+
"""
116+
Get list of tensors from a Hugging Face model repository.
117+
118+
Returns a dictionary of tensor names and their metadata.
119+
Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url)
120+
"""
121+
# case 1: model has only one single model.safetensor file
122+
is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors")
123+
if is_single_file:
124+
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
125+
return cls.get_list_tensors(url)
126+
127+
# case 2: model has multiple files
128+
index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
129+
is_multiple_files = cls.check_file_exist(index_url)
130+
if is_multiple_files:
131+
# read the index file
132+
index_data = cls.get_data_by_range(index_url, 0)
133+
index_str = index_data.decode('utf-8')
134+
index_json = json.loads(index_str)
135+
assert index_json.get("weight_map") is not None, "weight_map not found in index file"
136+
weight_map = index_json["weight_map"]
137+
# get the list of files
138+
all_files = list(set(weight_map.values()))
139+
all_files.sort() # make sure we load shard files in order
140+
# get the list of tensors
141+
tensors: dict[str, RemoteTensor] = {}
142+
for file in all_files:
143+
url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
144+
for key, val in cls.get_list_tensors(url).items():
145+
tensors[key] = val
146+
return tensors
147+
148+
raise ValueError(f"Model {model_id} does not have any safetensor files")
149+
150+
@classmethod
151+
def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
152+
"""
153+
Get list of tensors from a remote safetensor file.
154+
155+
Returns a dictionary of tensor names and their metadata.
156+
Each tensor is represented as a tuple of (dtype, shape, offset_start, size)
157+
"""
158+
metadata, data_start_offset = cls.get_metadata(url)
159+
res: dict[str, RemoteTensor] = {}
160+
161+
for name, meta in metadata.items():
162+
if name == "__metadata__":
163+
continue
164+
if not isinstance(meta, dict):
165+
raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
166+
try:
167+
dtype = meta["dtype"]
168+
shape = meta["shape"]
169+
offset_start_relative, offset_end_relative = meta["data_offsets"]
170+
size = offset_end_relative - offset_start_relative
171+
offset_start = data_start_offset + offset_start_relative
172+
res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
173+
except KeyError as e:
174+
raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")
175+
176+
return res
177+
178+
@classmethod
179+
def get_metadata(cls, url: str) -> tuple[dict, int]:
180+
"""
181+
Get JSON metadata from a remote safetensor file.
182+
183+
Returns tuple of (metadata, data_start_offset)
184+
"""
185+
# Request first 5MB of the file (hopefully enough for metadata)
186+
read_size = 5 * 1024 * 1024
187+
raw_data = cls.get_data_by_range(url, 0, read_size)
188+
189+
# Parse header
190+
# First 8 bytes contain the metadata length as u64 little-endian
191+
if len(raw_data) < 8:
192+
raise ValueError("Not enough data to read metadata size")
193+
metadata_length = int.from_bytes(raw_data[:8], byteorder='little')
194+
195+
# Calculate the data start offset
196+
data_start_offset = 8 + metadata_length
197+
alignment = SafetensorRemote.ALIGNMENT
198+
if data_start_offset % alignment != 0:
199+
data_start_offset += alignment - (data_start_offset % alignment)
200+
201+
# Check if we have enough data to read the metadata
202+
if len(raw_data) < 8 + metadata_length:
203+
raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")
204+
205+
# Extract metadata bytes and parse as JSON
206+
metadata_bytes = raw_data[8:8 + metadata_length]
207+
metadata_str = metadata_bytes.decode('utf-8')
208+
try:
209+
metadata = json.loads(metadata_str)
210+
return metadata, data_start_offset
211+
except json.JSONDecodeError as e:
212+
raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")
213+
214+
@classmethod
215+
def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
216+
"""
217+
Get raw byte data from a remote file by range.
218+
If size is not specified, it will read the entire file.
219+
"""
220+
import requests
221+
from urllib.parse import urlparse
222+
223+
parsed_url = urlparse(url)
224+
if not parsed_url.scheme or not parsed_url.netloc:
225+
raise ValueError(f"Invalid URL: {url}")
226+
227+
headers = cls._get_request_headers()
228+
if size > -1:
229+
headers["Range"] = f"bytes={start}-{start + size}"
230+
response = requests.get(url, allow_redirects=True, headers=headers)
231+
response.raise_for_status()
232+
233+
# Get raw byte data
234+
return response.content[:size]
235+
236+
@classmethod
237+
def check_file_exist(cls, url: str) -> bool:
238+
"""
239+
Check if a file exists at the given URL.
240+
Returns True if the file exists, False otherwise.
241+
"""
242+
import requests
243+
from urllib.parse import urlparse
244+
245+
parsed_url = urlparse(url)
246+
if not parsed_url.scheme or not parsed_url.netloc:
247+
raise ValueError(f"Invalid URL: {url}")
248+
249+
try:
250+
headers = cls._get_request_headers()
251+
headers["Range"] = "bytes=0-0"
252+
response = requests.head(url, allow_redirects=True, headers=headers)
253+
# Success (2xx) or redirect (3xx)
254+
return 200 <= response.status_code < 400
255+
except requests.RequestException:
256+
return False
257+
258+
@classmethod
259+
def _get_request_headers(cls) -> dict[str, str]:
260+
"""Prepare common headers for requests."""
261+
headers = {"User-Agent": "convert_hf_to_gguf"}
262+
if os.environ.get("HF_TOKEN"):
263+
headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
264+
return headers

0 commit comments

Comments
 (0)