Skip to content

Commit 4096249

Browse files
committed
Update
1 parent dd79bc6 commit 4096249

File tree

8 files changed

+210
-181
lines changed

8 files changed

+210
-181
lines changed

python/mlc_chat/cli/compile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from mlc_chat.compiler import ( # pylint: disable=redefined-builtin
88
MODELS,
9-
QUANT,
9+
QUANTIZATION,
1010
OptimizationFlags,
1111
compile,
1212
)
@@ -51,7 +51,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
5151
"--quantization",
5252
type=str,
5353
required=True,
54-
choices=list(QUANT.keys()),
54+
choices=list(QUANTIZATION.keys()),
5555
help="Quantization format.",
5656
)
5757
parser.add_argument(
@@ -119,7 +119,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
119119
parsed.model_type = detect_model_type(parsed.model_type, parsed.config)
120120
compile(
121121
config=parsed.config,
122-
quantization=parsed.quantization,
122+
quantization=QUANTIZATION[parsed.quantization],
123123
model_type=parsed.model_type,
124124
target=target,
125125
opt=parsed.opt,

python/mlc_chat/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
from .flags_optimization import OptimizationFlags
88
from .model import MODEL_PRESETS, MODELS, Model
99
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
10-
from .quantization import QUANT
10+
from .quantization import QUANTIZATION

python/mlc_chat/compiler/compile.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77
from tvm import IRModule, relax
88
from tvm.target import Target
99

10-
from ..compiler.model import Model
1110
from ..support.style import bold
1211
from .flags_optimization import OptimizationFlags
12+
from .model import Model
13+
from .quantization import Quantization
1314

1415

1516
@dataclasses.dataclass
1617
class CompileArgs: # pylint: disable=too-many-instance-attributes
1718
"""Arguments to MLC LLM's compiler."""
1819

1920
config: Path
20-
quantization: str
21+
quantization: Quantization
2122
model: Model
2223
target: Target
2324
opt: OptimizationFlags
@@ -40,20 +41,19 @@ def _echo_args(args: CompileArgs) -> None:
4041

4142
def _compile(args: CompileArgs):
4243
model_config = args.model.config.from_file(args.config)
43-
model = args.model.model(model_config)
44-
mod, named_params = model.export_tvm(
44+
quantization = args.quantization
45+
model, _ = args.model.quantize[quantization.kind](model_config, quantization)
46+
mod, _named_params = model.export_tvm(
4547
spec=model.get_default_spec(), # type: ignore
4648
)
4749
with args.target:
4850
mod = relax.get_pipeline("mlc_llm")(mod)
49-
mod.show(black_format=False)
50-
for name, param in named_params:
51-
print(f"{name}: {param.shape} {param.dtype}")
51+
args.build_func(mod, args)
5252

5353

5454
def compile( # pylint: disable=too-many-arguments,redefined-builtin
5555
config: Path,
56-
quantization,
56+
quantization: Quantization,
5757
model_type: Model,
5858
target: Target,
5959
opt: OptimizationFlags,
Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1-
"""
2-
Quantization specs for Llama2 architecture.
3-
TODO: add docstring
4-
"""
1+
"""Quantization specs for Llama."""
52
from typing import Tuple
63

74
from tvm.relax.frontend import nn
85

96
from ..parameter import QuantizeMapping
10-
from ..quantization import GroupQuantizeConfig
7+
from ..quantization import GroupQuantize
8+
from .llama_config import LlamaConfig
119
from .llama_model import LlamaForCasualLM
1210

1311

14-
def llama_group_quantization(
15-
model: LlamaForCasualLM, quant_config: GroupQuantizeConfig
12+
def group_quant(
13+
model_config: LlamaConfig,
14+
quantization: GroupQuantize,
1615
) -> Tuple[nn.Module, QuantizeMapping]:
16+
"""Quantize a Llama2 model using group quantization."""
17+
model: nn.Module = LlamaForCasualLM(model_config)
1718
quant_map = QuantizeMapping({}, {})
18-
for i in range(len(model.model.layers)):
19-
model.model.layers[i] = quant_config.apply(
20-
model.model.layers[i], quant_map, f"model.layers.{i}"
21-
)
22-
model.model.embed_tokens = quant_config.apply(
23-
model.model.embed_tokens, quant_map, "model.embed_tokens"
19+
model = quantization.apply(
20+
model,
21+
quant_map,
22+
"model",
2423
)
25-
model.lm_head = quant_config.apply(model.lm_head, quant_map, "lm_head")
2624
return model, quant_map

python/mlc_chat/compiler/model/model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""A centralized registry of all existing model architures and their configurations."""
22
import dataclasses
3-
from typing import Any, Callable, Dict
3+
from typing import Any, Callable, Dict, Tuple
44

55
from tvm.relax.frontend import nn
66

77
from ..parameter import ExternMapping, QuantizeMapping
8-
from ..quantization.quantization import QuantizeConfig
9-
from . import llama_config, llama_model, llama_parameter
8+
from ..quantization.quantization import Quantization
9+
from . import llama_config, llama_model, llama_parameter, llama_quantization
1010

1111
ModelConfig = Any
1212
"""A ModelConfig is an object that represents a model architecture. It is required to have
@@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig:
1616
...
1717
"""
1818

19-
FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
20-
FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
19+
FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping]
20+
FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]]
2121

2222

2323
@dataclasses.dataclass
@@ -38,15 +38,16 @@ class Model:
3838
source : Dict[str, FuncGetExternMap]
3939
A dictionary that maps the name of a source format to parameter mapping.
4040
41-
quantize: Dict[str, FuncGetQuantMap]
42-
A dictionary that maps the name of a quantization method to quantization mapping.
41+
quantize: Dict[str, FuncQuantization]
42+
A dictionary that maps the name of a quantization method to quantized model and the
43+
quantization parameter mapping.
4344
"""
4445

4546
name: str
4647
config: ModelConfig
4748
model: Callable[[ModelConfig], nn.Module]
4849
source: Dict[str, FuncGetExternMap]
49-
quantize: Dict[str, FuncGetQuantMap]
50+
quantize: Dict[str, FuncQuantization]
5051

5152

5253
MODELS: Dict[str, Model] = {
@@ -58,7 +59,9 @@ class Model:
5859
"huggingface-torch": llama_parameter.huggingface,
5960
"huggingface-safetensor": llama_parameter.huggingface,
6061
},
61-
quantize={},
62+
quantize={
63+
"group-quant": llama_quantization.group_quant,
64+
},
6265
)
6366
}
6467

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""A subpackage for quantization and dequantization algorithms"""
2-
from .quantization import QUANT, QuantizeConfig
3-
from .group_quantization import GroupQuantizeConfig
2+
from .group_quantization import GroupQuantize
3+
from .quantization import QUANTIZATION, Quantization

0 commit comments

Comments (0)