1
1
"""A weight loader for HuggingFace's PyTorch format"""
2
- import dataclasses
2
+
3
3
import gc
4
4
import json
5
5
import logging
6
- import time
7
6
from collections import OrderedDict , defaultdict
8
- from contextlib import contextmanager
9
7
from pathlib import Path
10
- from typing import Dict , Iterator , List , Set , Tuple
8
+ from typing import Dict , Iterator , List , Tuple
11
9
12
10
import numpy as np
13
11
from tqdm import tqdm
14
12
from tvm .runtime import NDArray
15
13
from tvm .runtime .ndarray import array as as_ndarray
16
14
17
15
from .mapping import ExternMapping
16
+ from .stats import Stats
17
+ from .utils import check_parameter_usage , load_safetensor_shard , load_torch_shard
18
18
19
19
logger = logging .getLogger (__name__ )
20
20
21
21
22
@dataclasses.dataclass
class Stats:
    """Statistics of the loading process of HuggingFace PyTorch loader.

    Attributes
    ----------
    load_time_sec : float
        Time used in loading the parameters.

    map_time_sec : float
        Time used in applying the mapping function, i.e. `ExternMapping.map_func`.

    quant_time_sec : float
        Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.

    current_memory_gb : float
        The current RAM usage in GB.

    total_memory_gb : float
        The total size data loaded from disk in GB.

    max_memory_gb : float
        The maximum RAM usage in GB.
    """

    load_time_sec: float = 0.0
    map_time_sec: float = 0.0
    quant_time_sec: float = 0.0

    current_memory_gb: float = 0.0
    total_memory_gb: float = 0.0
    max_memory_gb: float = 0.0

    def timer(self, attr: str):
        """A context manager to time the scope and add the time to the attribute.

        Parameters
        ----------
        attr : str
            Name of one of the `*_time_sec` attributes to accumulate into.
        """

        @contextmanager
        def timed_scope():
            # perf_counter is monotonic, unlike time.time(), so the measured
            # interval cannot go negative if the wall clock is adjusted.
            start_time = time.perf_counter()
            try:
                yield
            finally:
                # Accumulate even when the timed scope raises, so partially
                # completed work is still accounted for.
                elapsed_time = time.perf_counter() - start_time
                setattr(self, attr, getattr(self, attr) + elapsed_time)

        return timed_scope()

    def mem_add(self, nbytes: int) -> None:
        """Add the memory usage by the given number of bytes."""
        mem_gb = nbytes / (1024**3)
        self.current_memory_gb += mem_gb
        self.total_memory_gb += mem_gb
        self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)

    def mem_rm(self, nbytes: int) -> None:
        """Remove the memory usage by the given number of bytes."""
        self.current_memory_gb -= nbytes / (1024**3)
79
-
80
- class HFTorchLoader : # pylint: disable=too-few-public-methods
81
- """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters.
22
+ class HFLoader : # pylint: disable=too-few-public-methods
23
+ """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them
24
+ to MLC's parameters.
82
25
83
26
Attributes
84
27
----------
85
28
stats : Stats
86
29
Statistics of the loading process.
87
30
88
31
extern_param_map : ExternMapping
89
- The parameter mapping from MLC to HuggingFace PyTorch.
32
+ The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor .
90
33
91
34
torch_to_path : Dict[str, Path]
92
- A mapping from PyTorch parameter name to the path of the file containing it, or the path
93
- meaning all parameters are stored in a single file.
35
+ A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it,
36
+ or the path meaning all parameters are stored in a single file.
94
37
95
38
cached_files : Dict[Path, Dict[str, np.ndarray]]
96
39
A cache of the loaded files. The key is the path of the file, and the value is a mapping
@@ -113,20 +56,23 @@ def __init__(
113
56
----------
114
57
path : pathlib.Path
115
58
Path to either a JSON indexing file, or a PyTorch bin file.
116
- 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo,
117
- which contains a `weight_map` that maps each PyTorch parameter to the file containing
118
- the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
59
+ 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json`
60
+ or `model.safetensors.index.json` in the repo, which contains a `weight_map` that
61
+ maps each PyTorch parameter to the file containing the weight.
62
+ 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
63
+ which contains all the parameters.
64
+ 3) For safetensor file, it is usually `model.safetensors` in the repo,
119
65
which contains all the parameters.
120
66
121
67
extern_param_map : ExternMapping
122
- Maps an MLC parameter to a list of PyTorch parameters.
68
+ Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
123
69
"""
124
70
assert path .is_file ()
125
71
self .stats = Stats ()
126
72
self .extern_param_map = extern_param_map
127
73
self .cached_files = {}
128
74
self .torch_to_path = {}
129
- if path .suffix == ".bin" :
75
+ if path .suffix in ( ".bin" , ".safetensors" ) :
130
76
self ._load_file (path )
131
77
for name in self .cached_files [path ].keys ():
132
78
self .torch_to_path [name ] = path
@@ -137,7 +83,7 @@ def __init__(
137
83
self .torch_to_path [torch_name ] = path .parent / path_str
138
84
else :
139
85
raise FileNotFoundError (f"Unknown file suffix: { path } " )
140
- _check_parameter_usage (extern_param_map , set (self .torch_to_path .keys ()))
86
+ check_parameter_usage (extern_param_map , set (self .torch_to_path .keys ()))
141
87
142
88
def load (self ) -> Iterator [Tuple [str , NDArray ]]:
143
89
"""Load the parameters and yield the MLC parameter and its value."""
@@ -148,21 +94,8 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
148
94
cached_files = list (self .cached_files .keys ())
149
95
for path in cached_files :
150
96
self ._unload_file (path )
151
-
152
- logger .info (
153
- "Time used: "
154
- "PyTorch loading: %.3f sec; "
155
- "Pre-quantization mapping: %.3f sec; "
156
- "Quantization: %.3f sec" ,
157
- self .stats .load_time_sec ,
158
- self .stats .map_time_sec ,
159
- self .stats .quant_time_sec ,
160
- )
161
- logger .info (
162
- "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB" ,
163
- self .stats .total_memory_gb ,
164
- self .stats .max_memory_gb ,
165
- )
97
+ self .stats .log_time_info ("HF" )
98
+ self .stats .log_mem_usage ()
166
99
167
100
def _load_mlc_param (self , mlc_name : str ) -> np .ndarray :
168
101
torch_names = self .extern_param_map .param_map [mlc_name ]
@@ -190,53 +123,24 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
190
123
return param
191
124
192
125
def _load_file (self , path : Path ) -> None :
193
- logger .info ("Loading PyTorch parameters from: %s" , path )
126
+ logger .info ("Loading HF parameters from: %s" , path )
127
+ load_func = load_safetensor_shard if path .suffix == ".safetensors" else load_torch_shard
194
128
with self .stats .timer ("load_time_sec" ):
195
129
result = {}
196
- for name , param in _load_torch_shard (path ):
130
+ for name , param in load_func (path ):
197
131
result [name ] = param
198
132
self .stats .mem_add (param .nbytes )
199
133
self .cached_files [path ] = result
200
134
201
135
def _unload_file (self , path : Path ) -> None :
202
- logger .info ("Unloading PyTorch weight file: %s" , path )
136
+ logger .info ("Unloading HF weight file: %s" , path )
203
137
with self .stats .timer ("load_time_sec" ):
204
138
for _ , param in self .cached_files [path ].items ():
205
139
self .stats .mem_rm (param .nbytes )
206
140
del self .cached_files [path ]
207
141
gc .collect ()
208
142
209
143
210
def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
    """Cross-check the mapping against the weights actually present on disk.

    Warns about on-disk parameters that no MLC parameter consumes (unless
    explicitly listed in `unused_params`), and raises if the mapping requires
    a parameter that none of the weight files provide.
    """
    used_torch_names = {name for names in param_map.param_map.values() for name in names}
    # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
    unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
    if unused_torch_names:
        logger.warning(
            "Unused torch parameters: %s",
            ", ".join(sorted(unused_torch_names)),
        )
    # Check 2. All PyTorch parameters required are stored in the weight files
    nonexistent_torch_names = used_torch_names - torch_weights
    if nonexistent_torch_names:
        raise ValueError(
            "The following torch parameters do not exist in the weight files:\n  "
            + "\n  ".join(sorted(nonexistent_torch_names)),
        )
226
-
227
-
228
- def _load_torch_shard (path : Path ):
229
- import torch # pylint: disable=import-outside-toplevel
230
-
231
- for name , param in torch .load (path , map_location = torch .device ("cpu" )).items ():
232
- param = param .detach ().cpu ()
233
- dtype = str (param .dtype )
234
- if dtype == "torch.bfloat16" :
235
- param = param .float ()
236
- param = param .numpy ()
237
- yield name , param
238
-
239
-
240
144
def _loading_order (param_map : ExternMapping , torch_to_path : Dict [str , Path ]) -> List [str ]:
241
145
# Step 1. Build a map from path to torch parameters
242
146
path_to_torch : Dict [Path , List [str ]] = defaultdict (list )
@@ -257,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
257
161
return list (order .keys ())
258
162
259
163
260
- __all__ = ["HFTorchLoader " ]
164
+ __all__ = ["HFLoader " ]
0 commit comments