
Commit 4042626

[Slim-LM] Enable loading from AWQ pre-quantized weight. (mlc-ai#1114)
* [SLM] Enable loading from AWQ pre-quantized weight.
* remove awq_loader.py
* Update to the latest commit
* Delete llama_parameter.py
* update unittest
* fix lint
* upd
* add Llama-2-7B-AWQ
1 parent 9869ca6 commit 4042626

File tree

11 files changed: +650 -14 lines


python/mlc_chat/compiler/loader/huggingface_loader.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def __init__(
         self.cached_files = {}
         self.torch_to_path = {}
         self.quantize_param_map = quantize_param_map
-        if path.suffix in (".bin", ".safetensors"):
+        if path.suffix in (".bin", ".safetensors", ".pt"):
             self._load_file(path)
             for name in self.cached_files[path].keys():
                 self.torch_to_path[name] = path
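
The functional change is the accepted suffix list: AWQ pre-quantized checkpoints are typically distributed as single .pt files, so the loader now routes them through the same single-file path as .bin and .safetensors. Below is a minimal, self-contained sketch of that routing; the helper name, the .json branch, and the example filename are illustrative assumptions, not part of the loader itself.

from pathlib import Path

def classify_weight_file(path: Path) -> str:
    """Illustrative stand-in for the loader's suffix check."""
    if path.suffix in (".bin", ".safetensors", ".pt"):
        return "single-file checkpoint"     # .pt now accepted for AWQ weights
    if path.suffix == ".json":
        return "sharded checkpoint index"   # hypothetical index-file branch
    raise ValueError(f"unsupported weight file: {path}")

print(classify_weight_file(Path("llama-2-7b-w4-g128-awq.pt")))  # single-file checkpoint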

python/mlc_chat/compiler/model/llama_loader.py

Lines changed: 74 additions & 0 deletions
@@ -9,6 +9,7 @@
 from ..loader import ExternMapping
 from ..quantization import Quantization
 from .llama_model import LlamaConfig, LlamaForCasualLM
+from .llama_quantization import awq_quant
 
 
 def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
@@ -82,3 +83,76 @@ def huggingface(model_config: LlamaConfig, quantization: Quantization) -> Extern
             ),
         )
     return mapping
+
+
+def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
+    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
+    the names of AWQ parameters.
+    Parameters
+    ----------
+    model_config : LlamaConfig
+        The configuration of the Llama model.
+
+    quantization : Quantization
+        The quantization configuration.
+
+    Returns
+    -------
+    param_map : ExternMapping
+        The parameter mapping from MLC to AWQ.
+    """
+    model, _ = awq_quant(model_config, quantization)
+    _, _named_params = model.export_tvm(spec=model.get_default_spec())
+    named_parameters = dict(_named_params)
+
+    mapping = ExternMapping()
+
+    for i in range(model_config.num_hidden_layers):
+        # Add QKV in self attention
+        attn = f"model.layers.{i}.self_attn"
+        for quantize_suffix in ["qweight", "qzeros", "scales"]:
+            mlc_name = f"{attn}.qkv_proj.{quantize_suffix}"
+            assert mlc_name in named_parameters
+            mlc_param = named_parameters[mlc_name]
+            mapping.add_mapping(
+                mlc_name,
+                [
+                    f"{attn}.q_proj.{quantize_suffix}",
+                    f"{attn}.k_proj.{quantize_suffix}",
+                    f"{attn}.v_proj.{quantize_suffix}",
+                ],
+                functools.partial(
+                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+
+        # Concat gate and up in MLP
+        mlp = f"model.layers.{i}.mlp"
+        for quantize_suffix in ["qweight", "qzeros", "scales"]:
+            mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}"
+            assert mlc_name in named_parameters
+            mlc_param = named_parameters[mlc_name]
+            mapping.add_mapping(
+                mlc_name,
+                [
+                    f"{mlp}.gate_proj.{quantize_suffix}",
+                    f"{mlp}.up_proj.{quantize_suffix}",
+                ],
+                functools.partial(
+                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
+                    dtype=mlc_param.dtype,
+                ),
+            )
+
+        # inv_freq is not used in the model
+        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")
+
+    for mlc_name, mlc_param in named_parameters.items():
+        if mlc_name not in mapping.param_map:
+            mapping.add_mapping(
+                mlc_name,
+                [mlc_name],
+                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
+            )
+    return mapping
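
To see what the new awq() mapping records, here is a minimal, self-contained sketch of the name fan-in for a single layer. It mirrors the mapping built above with a plain dict; ExternMapping itself and the dtype-casting concatenation functions are omitted, so this is illustration only.

layer = 0
attn = f"model.layers.{layer}.self_attn"
mlp = f"model.layers.{layer}.mlp"

fan_in = {}
for suffix in ("qweight", "qzeros", "scales"):
    # One MLC parameter is assembled from several AWQ parameters.
    fan_in[f"{attn}.qkv_proj.{suffix}"] = [
        f"{attn}.q_proj.{suffix}",
        f"{attn}.k_proj.{suffix}",
        f"{attn}.v_proj.{suffix}",
    ]
    fan_in[f"{mlp}.gate_up_proj.{suffix}"] = [
        f"{mlp}.gate_proj.{suffix}",
        f"{mlp}.up_proj.{suffix}",
    ]

for mlc_name, awq_names in fan_in.items():
    print(mlc_name, "<-", awq_names)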

python/mlc_chat/compiler/model/llama_quantization.py

Lines changed: 18 additions & 1 deletion
@@ -5,7 +5,7 @@
 from tvm.relax.frontend import nn
 
 from ..loader import QuantizeMapping
-from ..quantization import GroupQuantize
+from ..quantization import AWQQuantize, GroupQuantize
 from .llama_model import LlamaConfig, LlamaForCasualLM
 
 
@@ -15,6 +15,23 @@ def group_quant(
 ) -> Tuple[nn.Module, QuantizeMapping]:
     """Quantize a Llama2 model using group quantization."""
     model: nn.Module = LlamaForCasualLM(model_config)
+    model.to(quantization.model_dtype)
+    quant_map = QuantizeMapping({}, {})
+    model = quantization.quantize_model(
+        model,
+        quant_map,
+        "",
+    )
+    return model, quant_map
+
+
+def awq_quant(
+    model_config: LlamaConfig,
+    quantization: AWQQuantize,
+) -> Tuple[nn.Module, QuantizeMapping]:
+    """Quantize a Llama2 model using Activation-aware Weight Quantization(AWQ)."""
+    model: nn.Module = LlamaForCasualLM(model_config)
+    model.to(quantization.model_dtype)
     quant_map = QuantizeMapping({}, {})
     model = quantization.quantize_model(
         model,
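
Both helpers follow the same pattern: instantiate the unquantized LlamaForCasualLM, cast it to the quantization's model_dtype, then let quantize_model rewrite the module tree while recording parameter transforms in a QuantizeMapping. A self-contained sketch of that pattern using stand-in classes (the real ones live in mlc_chat.compiler and tvm.relax.frontend.nn; everything below is a mock for illustration, not the actual API):

from typing import Tuple

class MockModel:
    """Stand-in for LlamaForCasualLM (an nn.Module)."""
    def to(self, dtype: str) -> None:
        self.dtype = dtype

class MockQuantizeMapping:
    """Stand-in for QuantizeMapping: holds param-name and transform maps."""
    def __init__(self, param_map: dict, map_func: dict) -> None:
        self.param_map, self.map_func = param_map, map_func

class MockAWQQuantize:
    """Stand-in for AWQQuantize."""
    model_dtype = "float16"
    def quantize_model(self, model, quant_map, name_prefix):
        # A real implementation would swap Linear layers for AWQ layers and
        # populate quant_map; the mock returns the model untouched.
        return model

def awq_quant(config: dict, quantization: MockAWQQuantize) -> Tuple[MockModel, MockQuantizeMapping]:
    model = MockModel()
    model.to(quantization.model_dtype)        # cast params to the target dtype
    quant_map = MockQuantizeMapping({}, {})
    model = quantization.quantize_model(model, quant_map, "")
    return model, quant_map

model, quant_map = awq_quant({}, MockAWQQuantize())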

python/mlc_chat/compiler/model/model.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class Model:
         source={
             "huggingface-torch": llama_loader.huggingface,
             "huggingface-safetensor": llama_loader.huggingface,
+            "awq": llama_loader.awq,
         },
         quantize={
             "group-quant": llama_quantization.group_quant,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""A subpackage for quantization and dequantization algorithms"""
2+
from .awq_quantization import AWQQuantize
23
from .group_quantization import GroupQuantize
34
from .quantization import QUANTIZATION, Quantization
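
With this re-export, AWQQuantize becomes importable from the subpackage root next to the existing symbols. A one-line usage sketch, assuming the mlc_chat package from this commit is on the Python path:

# Illustrative only; requires mlc_chat as of this commit to be installed.
from mlc_chat.compiler.quantization import AWQQuantize, GroupQuantize, QUANTIZATION, Quantization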
