
Commit 7ae8c6d

[Slim-LM] Introduce HFLoader for loading PyTorch and SafeTensor weights (mlc-ai#1113)
1 parent e5927ce commit 7ae8c6d
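
For orientation, here is a minimal usage sketch of the loader this commit introduces, based on the `HFLoader.__init__` parameters and the `load()` signature in the diff below. The checkpoint path and the `ExternMapping` instance are placeholders, not part of this commit.

from pathlib import Path

from mlc_chat.compiler.parameter import ExternMapping, HFLoader

# Hypothetical mapping built elsewhere: it relates each MLC parameter name to the
# HuggingFace parameter names it is assembled from.
param_mapping: ExternMapping = ...

# The same loader accepts a sharded-checkpoint index (pytorch_model.bin.index.json /
# model.safetensors.index.json) or a single-file checkpoint (.bin / .safetensors).
loader = HFLoader(
    path=Path("dist/models/Llama-2-7b-hf/model.safetensors"),  # placeholder path
    extern_param_map=param_mapping,
)
for mlc_name, ndarray in loader.load():  # yields (str, tvm.runtime.NDArray) pairs
    print(mlc_name, ndarray.shape)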

File tree: 6 files changed, +231 -166 lines changed


python/mlc_chat/compiler/parameter/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@
 A subpackage of the compiler that represents mapping between external parameters, quantized
 parameters and parameters in MLC-defined models.
 """
-from .hf_torch_loader import HFTorchLoader
+from .hf_loader import HFLoader
 from .mapping import ExternMapping, QuantizeMapping
python/mlc_chat/compiler/parameter/hf_torch_loader.py → hf_loader.py
@@ -1,96 +1,39 @@
 """A weight loader for HuggingFace's PyTorch format"""
-import dataclasses
+
 import gc
 import json
 import logging
-import time
 from collections import OrderedDict, defaultdict
-from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterator, List, Set, Tuple
+from typing import Dict, Iterator, List, Tuple
 
 import numpy as np
 from tqdm import tqdm
 from tvm.runtime import NDArray
 from tvm.runtime.ndarray import array as as_ndarray
 
 from .mapping import ExternMapping
+from .stats import Stats
+from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class Stats:
-    """Statistics of the loading process of HuggingFace PyTorch loader.
-
-    Attributes
-    ----------
-    load_time_sec : float
-        Time used in loading the parameters.
-
-    map_time_sec : float
-        Time used in applying the mapping function, i.e. `ExternMapping.map_func`.
-
-    quant_time_sec : float
-        Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.
-
-    current_memory_gb : float
-        The current RAM usage in GB.
-
-    total_memory_gb : float
-        The total size data loaded from disk in GB.
-
-    max_memory_gb : float
-        The maximum RAM usage in GB.
-    """
-
-    load_time_sec: float = 0.0
-    map_time_sec: float = 0.0
-    quant_time_sec: float = 0.0
-
-    current_memory_gb: float = 0.0
-    total_memory_gb: float = 0.0
-    max_memory_gb: float = 0.0
-
-    def timer(self, attr):
-        """A context manager to time the scope and add the time to the attribute."""
-
-        @contextmanager
-        def timed_scope():
-            start_time = time.time()
-            yield
-            elapsed_time = time.time() - start_time
-            setattr(self, attr, getattr(self, attr) + elapsed_time)
-
-        return timed_scope()
-
-    def mem_add(self, nbytes: int):
-        """Add the memory usage by the given number of bytes."""
-        mem_gb = float(nbytes) / float(1024**3)
-        self.current_memory_gb += mem_gb
-        self.total_memory_gb += mem_gb
-        self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)
-
-    def mem_rm(self, nbytes: int):
-        """Remove the memory usage by the given number of bytes."""
-        mem_gb = float(nbytes) / float(1024**3)
-        self.current_memory_gb -= mem_gb
-
-
-class HFTorchLoader:  # pylint: disable=too-few-public-methods
-    """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters.
+class HFLoader:  # pylint: disable=too-few-public-methods
+    """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them
+    to MLC's parameters.
 
     Attributes
     ----------
     stats : Stats
         Statistics of the loading process.
 
     extern_param_map : ExternMapping
-        The parameter mapping from MLC to HuggingFace PyTorch.
+        The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor.
 
     torch_to_path : Dict[str, Path]
-        A mapping from PyTorch parameter name to the path of the file containing it, or the path
-        meaning all parameters are stored in a single file.
+        A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it,
+        or the path meaning all parameters are stored in a single file.
 
     cached_files : Dict[Path, Dict[str, np.ndarray]]
         A cache of the loaded files. The key is the path of the file, and the value is a mapping
@@ -113,20 +56,23 @@ def __init__(
         ----------
         path : pathlib.Path
             Path to either a JSON indexing file, or a PyTorch bin file.
-            1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo,
-            which contains a `weight_map` that maps each PyTorch parameter to the file containing
-            the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
+            1) For JSON indexing file, it is usually `pytorch_model.bin.index.json`
+            or `model.safetensors.index.json` in the repo, which contains a `weight_map` that
+            maps each PyTorch parameter to the file containing the weight.
+            2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo,
+            which contains all the parameters.
+            3) For safetensor file, it is usually `model.safetensors` in the repo,
             which contains all the parameters.
 
         extern_param_map : ExternMapping
-            Maps an MLC parameter to a list of PyTorch parameters.
+            Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
         """
         assert path.is_file()
         self.stats = Stats()
         self.extern_param_map = extern_param_map
         self.cached_files = {}
         self.torch_to_path = {}
-        if path.suffix == ".bin":
+        if path.suffix in (".bin", ".safetensors"):
             self._load_file(path)
             for name in self.cached_files[path].keys():
                 self.torch_to_path[name] = path
@@ -137,7 +83,7 @@ def __init__(
                 self.torch_to_path[torch_name] = path.parent / path_str
         else:
             raise FileNotFoundError(f"Unknown file suffix: {path}")
-        _check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))
+        check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))
 
     def load(self) -> Iterator[Tuple[str, NDArray]]:
         """Load the parameters and yield the MLC parameter and its value."""
@@ -148,21 +94,8 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
         cached_files = list(self.cached_files.keys())
         for path in cached_files:
             self._unload_file(path)
-
-        logger.info(
-            "Time used: "
-            "PyTorch loading: %.3f sec; "
-            "Pre-quantization mapping: %.3f sec; "
-            "Quantization: %.3f sec",
-            self.stats.load_time_sec,
-            self.stats.map_time_sec,
-            self.stats.quant_time_sec,
-        )
-        logger.info(
-            "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
-            self.stats.total_memory_gb,
-            self.stats.max_memory_gb,
-        )
+        self.stats.log_time_info("HF")
+        self.stats.log_mem_usage()
 
     def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
         torch_names = self.extern_param_map.param_map[mlc_name]
@@ -190,53 +123,24 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
         return param
 
     def _load_file(self, path: Path) -> None:
-        logger.info("Loading PyTorch parameters from: %s", path)
+        logger.info("Loading HF parameters from: %s", path)
+        load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard
         with self.stats.timer("load_time_sec"):
             result = {}
-            for name, param in _load_torch_shard(path):
+            for name, param in load_func(path):
                 result[name] = param
                 self.stats.mem_add(param.nbytes)
             self.cached_files[path] = result
 
     def _unload_file(self, path: Path) -> None:
-        logger.info("Unloading PyTorch weight file: %s", path)
+        logger.info("Unloading HF weight file: %s", path)
         with self.stats.timer("load_time_sec"):
             for _, param in self.cached_files[path].items():
                 self.stats.mem_rm(param.nbytes)
             del self.cached_files[path]
             gc.collect()
 
 
-def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
-    used_torch_names = set(sum(param_map.param_map.values(), ()))
-    # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
-    unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
-    if unused_torch_names:
-        logger.warning(
-            "Unused torch parameters: %s",
-            ", ".join(sorted(unused_torch_names)),
-        )
-    # Check 2. All PyTorch parameters required are stored in the weight files
-    nonexistent_torch_names = used_torch_names - torch_weights
-    if nonexistent_torch_names:
-        raise ValueError(
-            "The following torch parameters do not exist in the weight files:\n  "
-            + "\n  ".join(sorted(nonexistent_torch_names)),
-        )
-
-
-def _load_torch_shard(path: Path):
-    import torch  # pylint: disable=import-outside-toplevel
-
-    for name, param in torch.load(path, map_location=torch.device("cpu")).items():
-        param = param.detach().cpu()
-        dtype = str(param.dtype)
-        if dtype == "torch.bfloat16":
-            param = param.float()
-        param = param.numpy()
-        yield name, param
-
-
 def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]:
     # Step 1. Build a map from path to torch parameters
     path_to_torch: Dict[Path, List[str]] = defaultdict(list)
@@ -257,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
     return list(order.keys())
 
 
-__all__ = ["HFTorchLoader"]
+__all__ = ["HFLoader"]
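
A side note on the JSON-indexing branch above: the `weight_map` inside the index file is what populates `torch_to_path`. A small standalone sketch of that lookup; the file names are illustrative placeholders, not from this commit.

import json
from pathlib import Path

# Index JSON of a sharded checkpoint, e.g. pytorch_model.bin.index.json or
# model.safetensors.index.json (placeholder path).
index_path = Path("dist/models/Llama-2-7b-hf/model.safetensors.index.json")

with index_path.open("r", encoding="utf-8") as in_file:
    weight_map = json.load(in_file)["weight_map"]

# weight_map looks like {"model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...};
# resolving each entry against the index's directory gives the name -> shard-file map.
name_to_shard = {name: index_path.parent / shard for name, shard in weight_map.items()}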
python/mlc_chat/compiler/parameter/stats.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+"""Statistics of the loading process of parameter loaders"""
+import dataclasses
+import logging
+import time
+from contextlib import contextmanager
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class Stats:
+    """Statistics of the loading process of parameter loaders.
+
+    Attributes
+    ----------
+    load_time_sec : float
+        Time used in loading the parameters.
+
+    map_time_sec : float
+        Time used in applying the mapping function, i.e. `ExternMapping.map_func`.
+
+    quant_time_sec : float
+        Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`.
+
+    current_memory_gb : float
+        The current RAM usage in GB.
+
+    total_memory_gb : float
+        The total size data loaded from disk in GB.
+
+    max_memory_gb : float
+        The maximum RAM usage in GB.
+    """
+
+    load_time_sec: float = 0.0
+    map_time_sec: float = 0.0
+    quant_time_sec: float = 0.0
+
+    current_memory_gb: float = 0.0
+    total_memory_gb: float = 0.0
+    max_memory_gb: float = 0.0
+
+    def timer(self, attr):
+        """A context manager to time the scope and add the time to the attribute."""
+
+        @contextmanager
+        def timed_scope():
+            start_time = time.time()
+            yield
+            elapsed_time = time.time() - start_time
+            setattr(self, attr, getattr(self, attr) + elapsed_time)
+
+        return timed_scope()
+
+    def mem_add(self, nbytes: int):
+        """Add the memory usage by the given number of bytes."""
+        mem_gb = float(nbytes) / float(1024**3)
+        self.current_memory_gb += mem_gb
+        self.total_memory_gb += mem_gb
+        self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb)
+
+    def mem_rm(self, nbytes: int):
+        """Remove the memory usage by the given number of bytes."""
+        mem_gb = float(nbytes) / float(1024**3)
+        self.current_memory_gb -= mem_gb
+
+    def log_time_info(self, weight_format: str):
+        """Log the time used in loading, pre-quantization and quantization."""
+        logger.info(
+            "Time used: "
+            "%s loading: %.3f sec; "
+            "Pre-quantization mapping: %.3f sec; "
+            "Quantization: %.3f sec",
+            weight_format,
+            self.load_time_sec,
+            self.map_time_sec,
+            self.quant_time_sec,
+        )
+
+    def log_mem_usage(self):
+        """Log the Memory usage information."""
+        logger.info(
+            "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
+            self.total_memory_gb,
+            self.max_memory_gb,
+        )
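
A short sketch of how these helpers compose, mirroring their use in HFLoader above; the module path in the import and the byte counts are assumptions for illustration only.

import logging
import time

from mlc_chat.compiler.parameter.stats import Stats  # assumed module path

logging.basicConfig(level=logging.INFO)

stats = Stats()
with stats.timer("load_time_sec"):   # accumulates elapsed time into stats.load_time_sec
    time.sleep(0.1)                  # stand-in for reading a shard from disk
stats.mem_add(4 * 1024**3)           # pretend a 4 GB shard is now cached in RAM
stats.mem_rm(4 * 1024**3)            # ... and released after conversion
stats.log_time_info("HF")            # "Time used: HF loading: ... sec; ..."
stats.log_mem_usage()                # "Memory usage: Total size loaded from disk: ... GB; ..."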
python/mlc_chat/compiler/parameter/utils.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+"""Common utilities for loading parameters"""
+import logging
+from pathlib import Path
+from typing import Iterator, Set, Tuple
+
+import numpy as np
+
+from .mapping import ExternMapping
+
+logger = logging.getLogger(__name__)
+
+
+def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]):
+    """Check that all external parameters have been used and are stored in the weights file."""
+    used_extern_names = set(sum(param_map.param_map.values(), []))
+    # Check 1. All extern parameters in the weight files are used unless explicitly specified
+    unused_extern_names = extern_weights - used_extern_names - param_map.unused_params
+    if unused_extern_names:
+        logger.warning(
+            "Unused extern parameters: %s",
+            ", ".join(sorted(unused_extern_names)),
+        )
+    # Check 2. All extern parameters required are stored in the weight files
+    nonexistent_extern_names = used_extern_names - extern_weights
+    if nonexistent_extern_names:
+        raise ValueError(
+            "The following extern parameters do not exist in the weight files:\n  "
+            + "\n  ".join(sorted(nonexistent_extern_names)),
+        )
+
+
+def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
+    """Load and yield PyTorch format parameters."""
+    import torch  # pylint: disable=import-outside-toplevel
+
+    for name, param in torch.load(path, map_location=torch.device("cpu")).items():
+        param = param.detach().cpu()
+        dtype = str(param.dtype)
+        if dtype == "torch.bfloat16":
+            param = param.float()
+        param = param.numpy()
+        yield name, param
+
+
+def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:
+    """Load and yield SafeTensor format parameters."""
+    import safetensors  # pylint: disable=import-outside-toplevel,import-error
+
+    with safetensors.safe_open(path, framework="numpy", device="cpu") as in_file:
+        for name in in_file.keys():
+            param = in_file.get_tensor(name)
+            yield name, param
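
And a small sketch of driving the shard loaders directly, using the same suffix dispatch as `HFLoader._load_file`; the module import path and the checkpoint path are placeholders, and `torch` or `safetensors` must be installed for the respective branch.

from pathlib import Path

from mlc_chat.compiler.parameter.utils import (  # assumed module path
    load_safetensor_shard,
    load_torch_shard,
)

shard = Path("model.safetensors")  # placeholder; could also be pytorch_model-00001-of-00002.bin
load_func = load_safetensor_shard if shard.suffix == ".safetensors" else load_torch_shard
for name, array in load_func(shard):
    # Both loaders yield (str, np.ndarray); the torch loader upcasts bfloat16 to float32.
    print(name, array.dtype, array.shape)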
