Commit b5bfa5b

cyx-6 and junrushao authored
Enable group quant transform with nn.Module (mlc-ai#1154)
* Enable group quant transform with nn.Module

  This PR completes the group quantization support for `nn.Module`-based models.

* remove deprecated tests
* Update
* wip
* remove deprecated test
* fix lint
* fix lint
* fix lint

---------

Co-authored-by: Junru Shao <[email protected]>
1 parent 9076d01 commit b5bfa5b

File tree: 12 files changed (+388, -349 lines)
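Taken together, the diffs below route quantization through the `nn.Module` path: the CLI resolves `--quantization` to a `Quantization` object, the compiler looks up the architecture's quantization function in the model registry by `quantization.kind`, and the already-quantized module is exported to TVM. A minimal sketch of that flow, assuming a "llama" registry key and a "q4f16_1" preset key (neither is shown in this commit):

# Hedged sketch of the new flow, pieced together from the diffs below.
# The "llama" and "q4f16_1" keys are assumptions, not part of this commit.
from mlc_chat.compiler import MODELS, QUANTIZATION

model_type = MODELS["llama"]                          # registry entry (assumed key)
quantization = QUANTIZATION["q4f16_1"]                # group-quant preset (assumed key)
config = model_type.config.from_file("config.json")  # hypothetical config path

# Dispatch by quantization kind, mirroring _compile() in compiler/compile.py.
model, quant_map = model_type.quantize[quantization.kind](config, quantization)
mod, _named_params = model.export_tvm(spec=model.get_default_spec())
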

python/mlc_chat/cli/compile.py

Lines changed: 3 additions & 3 deletions

@@ -6,7 +6,7 @@
 
 from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
     MODELS,
-    QUANT,
+    QUANTIZATION,
     OptimizationFlags,
     compile,
 )
@@ -51,7 +51,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
         "--quantization",
         type=str,
         required=True,
-        choices=list(QUANT.keys()),
+        choices=list(QUANTIZATION.keys()),
         help="Quantization format.",
     )
     parser.add_argument(
@@ -119,7 +119,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
     parsed.model_type = detect_model_type(parsed.model_type, parsed.config)
     compile(
         config=parsed.config,
-        quantization=parsed.quantization,
+        quantization=QUANTIZATION[parsed.quantization],
         model_type=parsed.model_type,
         target=target,
         opt=parsed.opt,

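The `--quantization` flag is now validated against the registry keys and resolved to a quantization object before `compile` is called. A self-contained, hedged sketch of that pattern with a stand-in registry (the real `QUANTIZATION` dict and its preset names live in `mlc_chat.compiler.quantization`):

# Minimal sketch of the registry-backed flag; the registry values are stand-ins.
import argparse

QUANTIZATION = {"q4f16_1": "group-quant-preset", "q0f16": "no-quant-preset"}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization",
    type=str,
    required=True,
    choices=list(QUANTIZATION.keys()),  # choices come straight from the registry
    help="Quantization format.",
)
args = parser.parse_args(["--quantization", "q4f16_1"])
quantization = QUANTIZATION[args.quantization]  # resolve name -> preset object
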
python/mlc_chat/compiler/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -7,4 +7,4 @@
 from .flags_optimization import OptimizationFlags
 from .model import MODEL_PRESETS, MODELS, Model
 from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
-from .quantization import QUANT
+from .quantization import QUANTIZATION

python/mlc_chat/compiler/compile.py

Lines changed: 8 additions & 8 deletions

@@ -7,17 +7,18 @@
 from tvm import IRModule, relax
 from tvm.target import Target
 
-from ..compiler.model import Model
 from ..support.style import bold
 from .flags_optimization import OptimizationFlags
+from .model import Model
+from .quantization import Quantization
 
 
 @dataclasses.dataclass
 class CompileArgs:  # pylint: disable=too-many-instance-attributes
     """Arguments to MLC LLM's compiler."""
 
     config: Path
-    quantization: str
+    quantization: Quantization
     model: Model
     target: Target
     opt: OptimizationFlags
@@ -40,20 +41,19 @@ def _echo_args(args: CompileArgs) -> None:
 
 def _compile(args: CompileArgs):
     model_config = args.model.config.from_file(args.config)
-    model = args.model.model(model_config)
-    mod, named_params = model.export_tvm(
+    quantization = args.quantization
+    model, _ = args.model.quantize[quantization.kind](model_config, quantization)
+    mod, _named_params = model.export_tvm(
         spec=model.get_default_spec(),  # type: ignore
     )
     with args.target:
         mod = relax.get_pipeline("mlc_llm")(mod)
-    mod.show(black_format=False)
-    for name, param in named_params:
-        print(f"{name}: {param.shape} {param.dtype}")
+    args.build_func(mod, args)
 
 
 def compile(  # pylint: disable=too-many-arguments,redefined-builtin
     config: Path,
-    quantization,
+    quantization: Quantization,
     model_type: Model,
     target: Target,
     opt: OptimizationFlags,

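`_compile` no longer dumps the module and parameter list; it forwards the lowered IRModule to `args.build_func`. From the call site alone, a compatible build callback looks roughly like this (a hedged stand-in; the actual build function and the `CompileArgs.build_func` field are defined elsewhere in the package):

# Hedged stand-in compatible with the call `args.build_func(mod, args)` above.
from tvm import IRModule

def echo_build(mod: IRModule, args: "CompileArgs") -> None:
    """Placeholder build step: just dump the lowered Relax module."""
    mod.show(black_format=False)
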
python/mlc_chat/compiler/model/llama_model.py

Lines changed: 6 additions & 1 deletion

@@ -22,6 +22,8 @@ def __init__(self, config: LlamaConfig):
 
     def forward(self, q: Tensor, k: Tensor, offset: tir.Var):
         def te_op(x: te.Tensor, offset: tir.Var):
+            dtype = x.dtype
+
             def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
                 head_dim = tir.const(self.head_dim, "int32")
                 position_embedding_base = tir.const(self.position_embedding_base, "float32")
@@ -30,11 +32,13 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):
                     (d * 2 % head_dim).astype("float32") / head_dim,
                 )
                 freq = (offset + s) / freq
-                return tir.cos(freq) * x[b, s, h, d] + tir.sin(freq) * tir.if_then_else(
+                cos = tir.cos(freq).astype(dtype) * x[b, s, h, d]
+                sin = tir.sin(freq).astype(dtype) * tir.if_then_else(
                     d < self.head_dim // 2,
                     -x[b, s, h, d + self.head_dim // 2],
                     x[b, s, h, d - self.head_dim // 2],
                 )
+                return cos + sin
 
             return te.compute(x.shape, compute, name="rotary")
 
@@ -87,6 +91,7 @@ def forward( # pylint: disable=too-many-locals
         d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len
         b, s, _ = hidden_states.shape
         assert b == 1, "Only support batch size 1 at this moment."
+
         q, k, v = self.qkv_proj(hidden_states)
         q = op.reshape(q, (b, s, h_q, d))
         k = op.reshape(k, (b, s, h_kv, d))
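The rotary-embedding change above splits the expression so the `cos`/`sin` factors are cast back to the activation dtype (e.g. float16) before they multiply the input, instead of letting the whole expression drift to float32. For intuition only, a hedged NumPy restatement of the same computation (`apply_rotary_ref` is made up; the real kernel is the TE `compute` above):

# Illustrative NumPy re-statement of the rotary compute; not from the commit.
import numpy as np

def apply_rotary_ref(x: np.ndarray, offset: int, base: float = 10000.0) -> np.ndarray:
    """x: (batch, seq, heads, head_dim); returns the same shape and dtype."""
    _, s, _, d = x.shape
    pos = (offset + np.arange(s))[None, :, None, None]  # (1, s, 1, 1)
    idx = np.arange(d)[None, None, None, :]              # (1, 1, 1, d)
    freq = pos / np.power(base, (idx * 2 % d).astype("float32") / d)
    # Partner lanes, matching the if_then_else over d < head_dim // 2 above.
    partner = np.concatenate([-x[..., d // 2:], x[..., : d // 2]], axis=-1)
    # Cast cos/sin to the activation dtype before the multiply, as the patch does.
    return np.cos(freq).astype(x.dtype) * x + np.sin(freq).astype(x.dtype) * partner
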

python/mlc_chat/compiler/model/llama_quantization.py

Lines changed: 16 additions & 93 deletions

@@ -1,101 +1,24 @@
-"""
-Quantization specs for Llama2 architecture.
-TODO: add docstring
-"""
-from typing import Callable, Dict, List, Optional
+"""Quantization specs for Llama."""
+from typing import Tuple
 
-import tvm
-from tvm.runtime import NDArray
+from tvm.relax.frontend import nn
 
 from ..parameter import QuantizeMapping
-from ..quantization import QuantizeConfig
-from ..quantization.group_quantizer import te_quantize as te_group_quantize
+from ..quantization import GroupQuantize
 from .llama_config import LlamaConfig
 from .llama_model import LlamaForCasualLM
 
 
-def huggingface_group_quantize(
+def group_quant(
     model_config: LlamaConfig,
-    quantize_config: QuantizeConfig,
-    target: Optional[tvm.target.Target] = None,
-) -> QuantizeMapping:
-    """Returns a parameter mapping that maps a parameter in MLC LLM's model
-    definition to its eventual names and values after quantization.
-
-    Parameters
-    ----------
-    model_config : LlamaConfig
-        The configuration of the Llama model.
-    quantize_config : GroupQuantizeConfig
-        The configuration of the group quantization.
-    target : Optional[tvm.target.Target]
-        The target device to run the quantization on, by default None, which
-        means the quantization will be run on CPU.
-
-    Returns
-    -------
-    quantize_map : QuantizeMapping
-        The parameter mapping from a parameter in MLC LLM's model definition to
-        its eventual names and values after quantization.
-    """
-
-    def group_quantize(
-        param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
-    ):
-        if target is None or target.kind.name == "llvm":
-            target = tvm.target.Target("llvm")
-            device = tvm.cpu()
-        elif target.kind.name == "cuda":
-            device = tvm.cuda()
-        else:
-            raise ValueError(f"Invalid target device: {target}")
-        param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
-        weight_compute, scale_compute, other_computes = te_group_quantize(  # type: ignore
-            param_tensor, config
-        )
-        s = tvm.te.create_schedule(  # pylint: disable=invalid-name
-            [compute.op for compute in [weight_compute, scale_compute] + other_computes]
-        )
-        if target.kind.name == "cuda":
-            # thread_binding for cuda
-            for compute in [weight_compute, scale_compute] + other_computes:
-                xo, xi = s[compute].split(compute.op.axis[0], 256)  # pylint: disable=invalid-name
-                s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
-                s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
-        f_quantize = tvm.build(
-            s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
-        )
-        weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
-        scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
-        f_quantize(param.copyto(device), weight, scale)
-        return weight, scale
-
-    # Param check
-    assert (
-        quantize_config.kind == "group_quantize"
-    ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
-    assert (
-        quantize_config.name == "q4f16_1"
-    ), """Only support q4f16_1 quantization scheme for now."""
-
-    # Fetch model parameter & names
-    model = LlamaForCasualLM(model_config)
-    _, named_params = model.export_tvm(spec=model.get_default_spec())
-    parameter_names = {name for name, _ in named_params}
-
-    # Init mappings
-    param_map: Dict[str, List[str]] = {}
-    map_func: Dict[str, Callable] = {}
-
-    # Dispatch quantization scheme
-    # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
-    for name in parameter_names:
-        if "norm.weight" not in name and "embed" not in name:
-            param_map[name] = [f"{name}_quantized", f"{name}_scale"]
-            map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
-        else:
-            # skip these parameters
-            param_map[name] = [name]
-            map_func[name] = lambda x: [x]
-
-    return QuantizeMapping(param_map, map_func)
+    quantization: GroupQuantize,
+) -> Tuple[nn.Module, QuantizeMapping]:
+    """Quantize a Llama2 model using group quantization."""
+    model: nn.Module = LlamaForCasualLM(model_config)
+    quant_map = QuantizeMapping({}, {})
+    model = quantization.quantize_model(
+        model,
+        quant_map,
+        "model",
+    )
+    return model, quant_map

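`group_quant` is now a thin wrapper over `GroupQuantize.quantize_model`: it builds the unquantized `LlamaForCasualLM`, lets the quantizer transform the module while recording the parameter mapping, and returns both. A hedged usage sketch (the import paths, config path, and "q4f16_1" preset key are assumptions):

# Hedged usage sketch; not part of the commit. Assumes a GroupQuantize preset
# is registered under "q4f16_1" and a Llama config JSON is on disk.
from mlc_chat.compiler.model.llama_config import LlamaConfig
from mlc_chat.compiler.model.llama_quantization import group_quant
from mlc_chat.compiler.quantization import QUANTIZATION

config = LlamaConfig.from_file("llama-2-7b/config.json")  # hypothetical path
quantization = QUANTIZATION["q4f16_1"]                     # assumed preset key
model, quant_map = group_quant(config, quantization)
mod, _ = model.export_tvm(spec=model.get_default_spec())   # quantized module exports as usual
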
python/mlc_chat/compiler/model/model.py

Lines changed: 12 additions & 9 deletions

@@ -1,12 +1,12 @@
 """A centralized registry of all existing model architures and their configurations."""
 import dataclasses
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Tuple
 
 from tvm.relax.frontend import nn
 
 from ..parameter import ExternMapping, QuantizeMapping
-from ..quantization.quantization import QuantizeConfig
-from . import llama_config, llama_model, llama_parameter
+from ..quantization.quantization import Quantization
+from . import llama_config, llama_model, llama_parameter, llama_quantization
 
 ModelConfig = Any
 """A ModelConfig is an object that represents a model architecture. It is required to have
@@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig:
         ...
 """
 
-FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
-FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
+FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping]
+FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]]
 
 
 @dataclasses.dataclass
@@ -38,15 +38,16 @@ class Model:
     source : Dict[str, FuncGetExternMap]
         A dictionary that maps the name of a source format to parameter mapping.
 
-    quantize: Dict[str, FuncGetQuantMap]
-        A dictionary that maps the name of a quantization method to quantization mapping.
+    quantize: Dict[str, FuncQuantization]
+        A dictionary that maps the name of a quantization method to quantized model and the
+        quantization parameter mapping.
     """
 
     name: str
     config: ModelConfig
     model: Callable[[ModelConfig], nn.Module]
     source: Dict[str, FuncGetExternMap]
-    quantize: Dict[str, FuncGetQuantMap]
+    quantize: Dict[str, FuncQuantization]
 
 
 MODELS: Dict[str, Model] = {
@@ -58,7 +59,9 @@ class Model:
             "huggingface-torch": llama_parameter.huggingface,
             "huggingface-safetensor": llama_parameter.huggingface,
         },
-        quantize={},
+        quantize={
+            "group-quant": llama_quantization.group_quant,
+        },
     )
 }
 

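The new `FuncQuantization` alias spells out the contract an architecture must satisfy to be listed in `Model.quantize`: given a model config and a quantization preset, return the quantized `nn.Module` together with its `QuantizeMapping`. A hedged illustration for a made-up architecture (`TinyModel` is invented; the "q4f16_1" key is an assumption for illustration):

# Hedged illustration of the FuncQuantization contract; TinyModel is a made-up
# architecture and "q4f16_1" an assumed preset key.
from typing import Tuple

from tvm.relax.frontend import nn

from mlc_chat.compiler.parameter import QuantizeMapping
from mlc_chat.compiler.quantization import QUANTIZATION

class TinyModel(nn.Module):
    def __init__(self, config: dict):
        self.proj = nn.Linear(config["hidden_size"], config["hidden_size"], bias=False)

    def forward(self, x: nn.Tensor):
        return self.proj(x)

def tiny_group_quant(config: dict, quantization) -> Tuple[nn.Module, QuantizeMapping]:
    model: nn.Module = TinyModel(config)
    quant_map = QuantizeMapping({}, {})
    model = quantization.quantize_model(model, quant_map, "model")
    return model, quant_map

model, quant_map = tiny_group_quant({"hidden_size": 64}, QUANTIZATION["q4f16_1"])
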

python/mlc_chat/compiler/quantization/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,2 +1,3 @@
 """A subpackage for quantization and dequantization algorithms"""
-from .quantization import QUANT, QuantizeConfig
+from .group_quantization import GroupQuantize
+from .quantization import QUANTIZATION, Quantization
