Commit e25717b

Update
1 parent 410d475 commit e25717b

File tree: 8 files changed, +150 −79 lines changed

python/mlc_chat/compiler/model/llama.py

Lines changed: 2 additions & 27 deletions
@@ -2,43 +2,18 @@
 Implementation for Llama2 architecture.
 TODO: add docstring
 """
-import dataclasses
 import math
-from typing import Any, Dict, Optional
+from typing import Optional
 
 from tvm import te, tir
 from tvm.relax.frontend import nn
 from tvm.relax.frontend.nn import Tensor, op
 
-from ...support.config import ConfigBase
+from .llama_config import LlamaConfig
 
 # pylint: disable=invalid-name,missing-docstring
 
 
-@dataclasses.dataclass
-class LlamaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
-    hidden_act: str
-    hidden_size: int
-    intermediate_size: int
-    num_attention_heads: int
-    num_hidden_layers: int
-    rms_norm_eps: float
-    vocab_size: int
-    max_sequence_length: int = 2048
-    position_embedding_base: int = 10000
-    num_key_value_heads: int = 0
-    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    head_dim: int = 0
-
-    def __post_init__(self):
-        if self.num_key_value_heads == 0:
-            self.num_key_value_heads = self.num_attention_heads
-        if self.head_dim == 0:
-            self.head_dim = self.hidden_size // self.num_attention_heads
-        assert self.num_attention_heads % self.num_key_value_heads == 0
-        assert self.head_dim * self.num_attention_heads == self.hidden_size
-
-
 class RotaryEmbedding(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
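
Since llama.py now imports LlamaConfig from llama_config, existing call sites that import the config from the model module keep working unchanged, as the updated test below confirms:

    from mlc_chat.compiler.model.llama import LlamaConfig  # still resolves, via the new import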

python/mlc_chat/compiler/model/llama_config.py

Lines changed: 36 additions & 0 deletions
@@ -1,4 +1,40 @@
 """Common configuration for Llama models."""
+import dataclasses
+from typing import Any, Dict
+
+from ...support.config import ConfigBase
+
+
+@dataclasses.dataclass
+class LlamaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
+    """Configuration of the Llama model."""
+
+    hidden_act: str
+    hidden_size: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_hidden_layers: int
+    rms_norm_eps: float
+    vocab_size: int
+    max_sequence_length: int = 2048
+    position_embedding_base: int = 10000
+    num_key_value_heads: int = 0
+    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    head_dim: int = 0
+
+    def __post_init__(self):
+        if self.num_key_value_heads == 0:
+            self.num_key_value_heads = self.num_attention_heads
+        if self.head_dim == 0:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        assert self.num_attention_heads % self.num_key_value_heads == 0
+        assert self.head_dim * self.num_attention_heads == self.hidden_size
+
+    @staticmethod
+    def from_predefined(name: str) -> "LlamaConfig":
+        """Create a LlamaConfig from a predefined configuration."""
+        return LlamaConfig.from_dict(CONFIG[name])
+
 
 CONFIG = {
     "llama2_7b": {

python/mlc_chat/compiler/parameter/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,4 +2,5 @@
 A subpackage of the compiler that represents mapping between external parameters, quantized
 parameters and parameters in MLC-defined models.
 """
+from .hf_torch_loader import HFTorchLoader
 from .mapping import ExternMapping, QuantizeMapping
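
The re-export lets callers import the loader from the subpackage root, which the new test at the bottom of this commit relies on:

    from mlc_chat.compiler.parameter import HFTorchLoader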

python/mlc_chat/compiler/parameter/hf_torch_loader.py

Lines changed: 30 additions & 18 deletions
@@ -4,17 +4,17 @@
 import json
 import logging
 import time
-from collections import defaultdict
+from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Dict, Iterator, List, Set, Tuple
 
 import numpy as np
 from tqdm import tqdm
-from tqdm.contrib.logging import logging_redirect_tqdm
 from tvm.runtime import NDArray
+from tvm.runtime.ndarray import array as as_ndarray
 
-from .mapping import ExternMapping, QuantizeMapping
+from .mapping import ExternMapping
 
 logger = logging.getLogger(__name__)
@@ -140,22 +140,32 @@ def __init__(
         _check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))
 
     def load(self) -> Iterator[Tuple[str, NDArray]]:
+        """Load the parameters and yield the MLC parameter and its value."""
         mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)
-        with logging_redirect_tqdm():
-            for mlc_name in tqdm(mlc_names):
-                param = self._load_mlc_param(mlc_name)
-                yield mlc_name, param
+        for mlc_name in tqdm(mlc_names):
+            param = self._load_mlc_param(mlc_name)
+            yield mlc_name, param
         cached_files = list(self.cached_files.keys())
         for path in cached_files:
             self._unload_file(path)
-        # logger.info(
-        #     "Time used in PyTorch loading: %.3f sec. Total %.3f GB loaded",
-        #     self.stats_load_time_sec,
-        #     self.stats_load_data_gb,
-        # )
+
+        logger.info(
+            "Time used: "
+            "PyTorch loading: %.3f sec; "
+            "Pre-quantization mapping: %.3f sec; "
+            "Quantization: %.3f sec",
+            self.stats.load_time_sec,
+            self.stats.map_time_sec,
+            self.stats.quant_time_sec,
+        )
+        logger.info(
+            "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB",
+            self.stats.total_memory_gb,
+            self.stats.max_memory_gb,
+        )
 
     def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
-        torch_names = self.extern_param_map.name_map[mlc_name]
+        torch_names = self.extern_param_map.param_map[mlc_name]
         files_required = {self.torch_to_path[p] for p in torch_names}
         files_existing = set(self.cached_files.keys())
         files_to_load = files_required - files_existing
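
The stats object is not part of this diff; below is a minimal sketch of the interface the logging above assumes, with field and method names taken from their uses in this file (the class name is hypothetical):

    import time
    from contextlib import contextmanager
    from dataclasses import dataclass


    @dataclass
    class _LoaderStats:  # hypothetical stand-in for the loader's `stats` attribute
        load_time_sec: float = 0.0
        map_time_sec: float = 0.0
        quant_time_sec: float = 0.0
        total_memory_gb: float = 0.0
        max_memory_gb: float = 0.0

        @contextmanager
        def timer(self, attr: str):
            # Accumulate wall-clock time into the named counter,
            # e.g. `with stats.timer("map_time_sec"): ...`
            start = time.time()
            yield
            setattr(self, attr, getattr(self, attr) + time.time() - start)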
@@ -176,6 +186,7 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray:
         with self.stats.timer("map_time_sec"):
             param = self.extern_param_map.map_func[mlc_name](*torch_params)
         logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype)
+        param = as_ndarray(param)
         return param
 
     def _load_file(self, path: Path) -> None:
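
The added as_ndarray call (an alias of tvm.runtime.ndarray.array, i.e. tvm.nd.array) converts the mapped NumPy array to a TVM NDArray, so load() yields the Tuple[str, NDArray] its signature declares:

    import numpy as np
    from tvm.runtime.ndarray import array as as_ndarray

    param = as_ndarray(np.zeros((4, 4), dtype="float32"))  # a tvm.nd.NDArray on CPU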
@@ -197,7 +208,7 @@ def _unload_file(self, path: Path) -> None:
 
 
 def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]):
-    used_torch_names = set(sum(param_map.name_map.values(), ()))
+    used_torch_names = set(sum(param_map.param_map.values(), ()))
     # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified
     unused_torch_names = torch_weights - used_torch_names - param_map.unused_params
     if unused_torch_names:
@@ -233,16 +244,17 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
         path_to_torch[path].append(torch_name)
     # Step 2. Build a map from torch parameters to MLC parameters
     torch_to_mlc = defaultdict(list)
-    for mlc_name, torch_names in param_map.name_map.items():
+    for mlc_name, torch_names in param_map.param_map.items():
         for torch_name in torch_names:
             torch_to_mlc[torch_name].append(mlc_name)
     # Step 3. Construct the ordering that ensures file locality
-    order = []
+    order = OrderedDict()
     for _, torch_names in path_to_torch.items():
         for torch_name in torch_names:
             for mlc_name in torch_to_mlc[torch_name]:
-                order.append(mlc_name)
-    return order
+                if mlc_name not in order:
+                    order[mlc_name] = 1
+    return list(order.keys())
 
 
 __all__ = ["HFTorchLoader"]

python/mlc_chat/support/tqdm.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+"""Utils to better use tqdm"""
+import contextlib
+import inspect
+import io
+
+from tqdm import tqdm
+from tqdm.contrib.logging import logging_redirect_tqdm as _redirect_logging
+
+
+@contextlib.contextmanager
+def _redirect_print():
+    old_print = print
+
+    def new_print(*args, **kwargs):
+        with io.StringIO() as output:
+            kwargs["file"] = output
+            kwargs["end"] = ""
+            old_print(*args, **kwargs)
+            content = output.getvalue()
+        tqdm.write(content)
+
+    try:
+        inspect.builtins.print = new_print
+        yield
+    finally:
+        inspect.builtins.print = old_print
+
+
+@contextlib.contextmanager
+def redirect():
+    """Redirect tqdm output to logging and print."""
+
+    with _redirect_logging():
+        with _redirect_print():
+            yield
+
+
+__all__ = ["tqdm", "redirect"]

tests/python/model/test_llama.py

Lines changed: 1 addition & 2 deletions
@@ -1,12 +1,11 @@
 # pylint: disable=invalid-name,missing-docstring
 import pytest
 from mlc_chat.compiler.model.llama import LlamaConfig, LlamaForCasualLM
-from mlc_chat.compiler.model.llama_config import CONFIG
 
 
 @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"])
 def test_llama2_creation(model_name: str):
-    config = LlamaConfig.from_dict(CONFIG[model_name])
+    config = LlamaConfig.from_predefined(model_name)
     model = LlamaForCasualLM(config)
     mod, named_params = model.export_tvm(spec=model.get_default_spec())
     mod.show(black_format=False)

tests/python/parameter/hf_torch_loader.py

Lines changed: 0 additions & 32 deletions — this file was deleted.

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# pylint: disable=missing-docstring
+import logging
+from pathlib import Path
+
+import pytest
+from mlc_chat.compiler.model.llama import LlamaConfig
+from mlc_chat.compiler.model.llama_parameter import hf_torch
+from mlc_chat.compiler.parameter import HFTorchLoader
+from mlc_chat.support import tqdm
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    style="{",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
+)
+
+
+@pytest.mark.parametrize(
+    "base_path",
+    [
+        "./dist/models/Llama-2-7b-hf",
+        "./dist/models/Llama-2-13b-hf",
+        "./dist/models/Llama-2-70b-hf",
+    ],
+)
+def test_load_llama(base_path: str):
+    base_path = Path(base_path)
+    path_config = base_path / "config.json"
+    path_params = base_path / "pytorch_model.bin.index.json"
+
+    config = LlamaConfig.from_file(path_config)
+    loader = HFTorchLoader(path=path_params, extern_param_map=hf_torch(config))
+    with tqdm.redirect():
+        for _name, _param in loader.load():
+            ...
+
+
+if __name__ == "__main__":
+    test_load_llama(base_path="./dist/models/Llama-2-7b-hf")
+    test_load_llama(base_path="./dist/models/Llama-2-13b-hf")
+    test_load_llama(base_path="./dist/models/Llama-2-70b-hf")
