|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +from dataclasses import dataclass |
3 | 4 | from typing import Literal
|
4 | 5 |
|
| 6 | +import os |
| 7 | +import json |
| 8 | + |
5 | 9 |
|
6 | 10 | def fill_templated_filename(filename: str, output_type: str | None) -> str:
|
7 | 11 | # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf'
|
@@ -67,3 +71,194 @@ def naming_convention(model_name: str | None, base_name: str | None, finetune_st
|
67 | 71 | kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else ""
|
68 | 72 |
|
69 | 73 | return f"{name}{parameters}{finetune}{version}{encoding}{kind}"
|
| 74 | + |
| 75 | + |
@dataclass
class RemoteTensor:
    """Metadata describing one tensor inside a remote safetensor file."""

    dtype: str               # safetensors dtype string (e.g. "F32", "BF16")
    shape: tuple[int, ...]   # tensor dimensions as stored in the file metadata
    offset_start: int        # absolute byte offset of the tensor payload within the file
    size: int                # payload size in bytes
    url: str                 # URL of the safetensor file that holds this tensor

    def data(self) -> bytearray:
        """Download and return the raw tensor bytes.

        TODO: handle request errors (maybe with limited retries?)
        """
        raw = SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)
        # NOTE: wrap in a bytearray, otherwise PyTorch complains the buffer is not writeable
        return bytearray(raw)
| 89 | + |
| 90 | + |
class SafetensorRemote:
    """
    Utility class to handle remote safetensor files.
    This class is designed to work with Hugging Face model repositories.

    Example (one model has single safetensor file, the other has multiple):
        for model_id in ["ngxson/TEST-Tiny-Llama4", "Qwen/Qwen2.5-7B-Instruct"]:
            tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
            print(tensors)

    Example reading tensor data:
        tensors = SafetensorRemote.get_list_tensors_hf_model(model_id)
        for name, meta in tensors.items():
            # meta is a RemoteTensor; fetch its payload on demand
            data = meta.data()
            print(data)
    """

    BASE_DOMAIN = "https://huggingface.co"
    ALIGNMENT = 8  # bytes

    @classmethod
    def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a Hugging Face model repository.

        Returns a dictionary mapping tensor names to RemoteTensor objects.
        Raises ValueError if the repository has no safetensor files or the
        shard index is malformed.
        """
        # case 1: model has only one single model.safetensors file
        single_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors"
        if cls.check_file_exist(single_url):
            return cls.get_list_tensors(single_url)

        # case 2: model is sharded into multiple files listed in an index file
        index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json"
        if cls.check_file_exist(index_url):
            # read the index file
            index_data = cls.get_data_by_range(index_url, 0)
            index_json = json.loads(index_data.decode('utf-8'))
            weight_map = index_json.get("weight_map")
            # raise instead of assert: asserts are stripped under `python -O`
            if weight_map is None:
                raise ValueError("weight_map not found in index file")
            # deduplicate and sort so shard files are loaded in a deterministic order
            all_files = sorted(set(weight_map.values()))
            # collect tensors from every shard; later shards overwrite duplicates
            tensors: dict[str, RemoteTensor] = {}
            for file in all_files:
                url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/{file}"
                tensors.update(cls.get_list_tensors(url))
            return tensors

        raise ValueError(f"Model {model_id} does not have any safetensor files")

    @classmethod
    def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]:
        """
        Get list of tensors from a remote safetensor file.

        Returns a dictionary mapping tensor names to RemoteTensor objects.
        Raises ValueError when a tensor's metadata entry is malformed.
        """
        metadata, data_start_offset = cls.get_metadata(url)
        res: dict[str, RemoteTensor] = {}

        for name, meta in metadata.items():
            if name == "__metadata__":
                # header-level metadata entry, not a tensor
                continue
            if not isinstance(meta, dict):
                raise ValueError(f"Invalid metadata for tensor '{name}': {meta}")
            try:
                dtype = meta["dtype"]
                shape = meta["shape"]
                # data_offsets are relative to the start of the data section
                offset_start_relative, offset_end_relative = meta["data_offsets"]
                size = offset_end_relative - offset_start_relative
                # convert to an absolute offset within the file
                offset_start = data_start_offset + offset_start_relative
                res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url)
            except KeyError as e:
                raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}")

        return res

    @classmethod
    def get_metadata(cls, url: str) -> tuple[dict, int]:
        """
        Get JSON metadata from a remote safetensor file.

        Returns tuple of (metadata, data_start_offset)
        """
        # Request first 5MB of the file (hopefully enough for metadata)
        read_size = 5 * 1024 * 1024
        raw_data = cls.get_data_by_range(url, 0, read_size)

        # Parse header
        # First 8 bytes contain the metadata length as u64 little-endian
        if len(raw_data) < 8:
            raise ValueError("Not enough data to read metadata size")
        metadata_length = int.from_bytes(raw_data[:8], byteorder='little')

        # Calculate the data start offset, rounded up to the next ALIGNMENT boundary
        data_start_offset = 8 + metadata_length
        alignment = cls.ALIGNMENT
        if data_start_offset % alignment != 0:
            data_start_offset += alignment - (data_start_offset % alignment)

        # Check if we have enough data to read the metadata
        if len(raw_data) < 8 + metadata_length:
            raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}")

        # Extract metadata bytes and parse as JSON
        metadata_str = raw_data[8:8 + metadata_length].decode('utf-8')
        try:
            metadata = json.loads(metadata_str)
            return metadata, data_start_offset
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse safetensor metadata as JSON: {e}")

    @classmethod
    def get_data_by_range(cls, url: str, start: int, size: int = -1) -> bytes:
        """
        Get raw byte data from a remote file by range.
        If size is not specified, it will read the entire file.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        if size == 0:
            # nothing to read; avoid an invalid "bytes=start-(start-1)" Range request
            return b""

        headers = cls._get_request_headers()
        if size > 0:
            # HTTP Range end is inclusive: request exactly `size` bytes.
            # (`start + size` would fetch one extra byte.)
            headers["Range"] = f"bytes={start}-{start + size - 1}"
        response = requests.get(url, allow_redirects=True, headers=headers)
        response.raise_for_status()

        # Trim in case the server ignored the Range header. When size < 0 we
        # must return the whole body: slicing with -1 would drop the last byte.
        return response.content[:size] if size > 0 else response.content

    @classmethod
    def check_file_exist(cls, url: str) -> bool:
        """
        Check if a file exists at the given URL.
        Returns True if the file exists, False otherwise.
        """
        import requests
        from urllib.parse import urlparse

        parsed_url = urlparse(url)
        if not parsed_url.scheme or not parsed_url.netloc:
            raise ValueError(f"Invalid URL: {url}")

        try:
            headers = cls._get_request_headers()
            # ask for a single byte so the probe stays cheap even if the
            # server answers the HEAD with a body
            headers["Range"] = "bytes=0-0"
            response = requests.head(url, allow_redirects=True, headers=headers)
            # Success (2xx) or redirect (3xx)
            return 200 <= response.status_code < 400
        except requests.RequestException:
            return False

    @classmethod
    def _get_request_headers(cls) -> dict[str, str]:
        """Prepare common headers for requests (adds HF auth token if set)."""
        headers = {"User-Agent": "convert_hf_to_gguf"}
        if os.environ.get("HF_TOKEN"):
            headers["Authorization"] = f"Bearer {os.environ['HF_TOKEN']}"
        return headers
0 commit comments