 import json
 import logging
 import time
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, Iterator, List, Set, Tuple
 
 import numpy as np
 from tqdm import tqdm
-from tqdm.contrib.logging import logging_redirect_tqdm
 from tvm.runtime import NDArray
+from tvm.runtime.ndarray import array as as_ndarray
 
-from .mapping import ExternMapping, QuantizeMapping
+from .mapping import ExternMapping
 
 logger = logging.getLogger(__name__)
 
@@ -140,22 +140,32 @@ def __init__(
         _check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))
 
     def load(self) -> Iterator[Tuple[str, NDArray]]:
+        """Load the parameters and yield the MLC parameter and its value."""
         mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)
-        with logging_redirect_tqdm():
-            for mlc_name in tqdm(mlc_names):
-                param = self._load_mlc_param(mlc_name)
-                yield mlc_name, param
+        for mlc_name in tqdm(mlc_names):
+            param = self._load_mlc_param(mlc_name)
+            yield mlc_name, param
         cached_files = list(self.cached_files.keys())
         for path in cached_files:
             self._unload_file(path)
-        # logger.info(
-        #     "Time used in PyTorch loading: %.3f sec. Total %.3f GB loaded",
-        #     self.stats_load_time_sec,
-        #     self.stats_load_data_gb,
-        # )
+
+        logger.info(
+            "Time used: "
+            "PyTorch loading: %.3f sec; "
+            "Pre-quantization mapping: %.3f sec; "
+            "Quantization: %.3f sec",
+            self.stats.load_time_sec,
+            self.stats.map_time_sec,
+            self.stats.quant_time_sec,
+        )
+        logger.info(
+            "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
+            self.stats.total_memory_gb,
+            self.stats.max_memory_gb,
+        )
 
     def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
-        torch_names = self.extern_param_map.name_map[mlc_name]
+        torch_names = self.extern_param_map.param_map[mlc_name]
         files_required = {self.torch_to_path[p] for p in torch_names}
         files_existing = set(self.cached_files.keys())
         files_to_load = files_required - files_existing
@@ -176,6 +186,7 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
         with self.stats.timer("map_time_sec"):
             param = self.extern_param_map.map_func[mlc_name](*torch_params)
         logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype)
+        param = as_ndarray(param)
         return param
 
     def _load_file(self, path: Path) -> None:
@@ -197,7 +208,7 @@ def _unload_file(self, path: Path) -> None:
 
 
 def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
-    used_torch_names = set(sum(param_map.name_map.values(), ()))
+    used_torch_names = set(sum(param_map.param_map.values(), ()))
     # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
     unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
     if unused_torch_names:
@@ -233,16 +244,17 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
         path_to_torch[path].append(torch_name)
     # Step 2. Build a map from torch parameters to MLC parameters
     torch_to_mlc = defaultdict(list)
-    for mlc_name, torch_names in param_map.name_map.items():
+    for mlc_name, torch_names in param_map.param_map.items():
         for torch_name in torch_names:
             torch_to_mlc[torch_name].append(mlc_name)
     # Step 3. Construct the ordering that ensures file locality
-    order = []
+    order = OrderedDict()
     for _, torch_names in path_to_torch.items():
         for torch_name in torch_names:
             for mlc_name in torch_to_mlc[torch_name]:
-                order.append(mlc_name)
-    return order
+                if mlc_name not in order:
+                    order[mlc_name] = 1
+    return list(order.keys())
 
 
 __all__ = ["HFTorchLoader"]