Skip to content

Commit 61179a0

Browse files
authored
Add CLI commands for compilation (mlc-ai#1109)
1 parent 5a7dcd8 commit 61179a0

File tree

14 files changed

+612
-16
lines changed

14 files changed

+612
-16
lines changed

python/mlc_chat/cli/compile.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""Command line entrypoint of compilation."""
2+
import argparse
3+
import json
4+
import logging
5+
from pathlib import Path
6+
from typing import Union
7+
8+
from mlc_chat.compiler.compile import compile # pylint: disable=redefined-builtin
9+
from mlc_chat.compiler.model import MODELS, Model
10+
11+
from ..support.auto_config import detect_config
12+
from ..support.auto_target import detect_target_and_host
13+
14+
# Configure the root logger when this CLI module is imported: DEBUG level,
# brace-style ("{") format fields, and second-resolution timestamps.
logging.basicConfig(
    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
    datefmt="%Y-%m-%d %H:%M:%S",
    style="{",
    level=logging.DEBUG,
)
20+
21+
22+
def _parse_config(path: Union[str, Path]) -> Path:
    """Argparse ``type=`` converter for ``--config``.

    Wraps ``path`` in a ``Path``, hands it to ``detect_config`` and returns
    whatever path that helper reports.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``detect_config`` raises ``ValueError``. The original error is
        chained as the cause so the underlying reason is not lost.
    """
    try:
        return detect_config(Path(path))
    except ValueError as err:
        # Chain explicitly so the traceback shows the ValueError as the direct cause
        # rather than as an error "during handling" of another exception.
        raise argparse.ArgumentTypeError(
            f"No valid config.json in: {path}. Error: {err}"
        ) from err
27+
28+
29+
def _parse_output(path: Union[str, Path]) -> Path:
    """Argparse ``type=`` converter for ``--output``.

    Returns ``path`` as a ``Path`` when its parent directory exists; the output
    file itself does not need to exist.

    Raises
    ------
    argparse.ArgumentTypeError
        If the parent directory of ``path`` is not an existing directory.
    """
    output = Path(path)
    if output.parent.is_dir():
        return output
    raise argparse.ArgumentTypeError(f"Directory does not exist: {output.parent}")
35+
36+
37+
def _parse_model_type(model_type: str, config: Path) -> Model:
    """Resolve a model type name to its entry in ``MODELS``.

    When ``model_type`` is ``"auto"``, the name is read from the ``"model_type"``
    key of the JSON file at ``config``; otherwise ``model_type`` is used as given.

    Raises
    ------
    ValueError
        If ``"auto"`` is requested but the JSON has no ``"model_type"`` key, or
        if the resolved name is not a key of ``MODELS``.
    """
    if model_type == "auto":
        with open(config, "r", encoding="utf-8") as handle:
            hf_config = json.load(handle)
        if "model_type" in hf_config:
            model_type = hf_config["model_type"]
        else:
            raise ValueError(
                f"'model_type' not found in: {config}. "
                f"Please explicitly specify `--model-type` instead"
            )
    if model_type in MODELS:
        return MODELS[model_type]
    raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}")
50+
51+
52+
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser describing every compilation CLI option."""
    parser = argparse.ArgumentParser("MLC LLM Compiler")
    parser.add_argument(
        "--config",
        type=_parse_config,
        required=True,
        help="Path to config.json file or to the directory that contains config.json, which is "
        "a HuggingFace standard that defines model architecture, for example, "
        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
    )
    parser.add_argument(
        "--quantization",
        type=str,
        required=True,
        choices=[
            "q0f16",
            "q0f32",
            "q3f16_1",
            "q3f32_1",
            "q4f16_1",
            "q4f16_ft",
            "q4f32_1",
        ],
        help="The quantization format. TBD",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="auto",
        choices=["auto"] + list(MODELS.keys()),
        help="Model architecture, for example, llama. If not set, it is inferred "
        "from the config.json file.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="The GPU device to compile the model to. If not set, it is inferred from locally "
        "available GPUs.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="auto",
        choices=[
            "auto",
            "arm",
            "arm64",
            "aarch64",
            "x86-64",
        ],
        help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
        "local CPU.",
    )
    parser.add_argument(
        "--opt",
        type=str,
        default="",
        help="Optimization flags.",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=_parse_output,
        required=True,
        help="The name of the output file. The suffix determines if the output file is a "
        "shared library or a static library. Available suffixes: "
        "1) Linux: .so (shared), .a (static); "
        "2) macOS: .dylib (shared), .a (static); "
        "3) Windows: .dll (shared), .lib (static); "
        "4) Android, iOS: .tar (static); "
        "5) Web: .wasm (web assembly)",
    )
    return parser


def main():
    """Parse command line arguments and call `mlc_chat.compiler.compile`."""
    parsed = _build_parser().parse_args()
    # Target/host detection runs first, then the model type is resolved
    # (possibly by reading the config file when "--model-type auto" is used).
    target, build_func = detect_target_and_host(parsed.device, parsed.host)
    parsed.model_type = _parse_model_type(parsed.model_type, parsed.config)
    compile(
        config=parsed.config,
        quantization=parsed.quantization,
        model_type=parsed.model_type,
        target=target,
        opt=parsed.opt,
        build_func=build_func,
        output=parsed.output,
    )


if __name__ == "__main__":
    main()

python/mlc_chat/compiler/compile.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Python entrypoint of compilation."""
2+
import dataclasses
3+
import logging
4+
from io import StringIO
5+
from pathlib import Path
6+
from typing import Callable
7+
8+
from mlc_chat.compiler.model import Model
9+
from tvm import IRModule # pylint: disable=wrong-import-order
10+
from tvm.target import Target # pylint: disable=wrong-import-order
11+
12+
from ..support.style import bold
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
@dataclasses.dataclass
class CompileArgs:
    """Everything a single invocation of MLC LLM's compiler is parameterized by."""

    # Path of the model configuration (echoed as ``--config``).
    config: Path
    # Name of the quantization format (echoed as ``--quantization``).
    quantization: str
    # Registry entry describing the model architecture.
    model_type: Model
    # TVM target to compile for.
    target: Target
    # Optimization flags string (echoed as ``--opt``).
    opt: str
    # Callback taking the IRModule and these arguments.
    build_func: Callable[[IRModule, "CompileArgs"], None]
    # Path of the output artifact (echoed as ``--output``).
    output: Path
28+
29+
30+
def _echo_args(args: CompileArgs) -> None:
    """Print a summary of the compilation arguments to standard output.

    One header line is followed by one line per option; trailing whitespace of
    the assembled text is stripped before printing.
    """
    lines = [
        f"{bold('Compiling with arguments:')}",
        f" {bold('--config'):<25} {args.config}",
        f" {bold('--quantization'):<25} {args.quantization}",
        f" {bold('--model-type'):<25} {args.model_type.name}",
        f" {bold('--target'):<25} {args.target.export()}",
        f" {bold('--opt'):<25} {args.opt}",
        f" {bold('--output'):<25} {args.output}",
    ]
    print("\n".join(lines).rstrip())
40+
41+
42+
def compile(  # pylint: disable=too-many-arguments,redefined-builtin
    config: Path,
    quantization,
    model_type: Model,
    target: Target,
    opt,
    build_func: Callable[[IRModule, CompileArgs], None],
    output: Path,
):
    """Compile a model given its configuration and quantization format to a specific target.

    The arguments are bundled into a ``CompileArgs`` record, which is then
    echoed to standard output.
    """
    _echo_args(
        CompileArgs(
            config=config,
            quantization=quantization,
            model_type=model_type,
            target=target,
            opt=opt,
            build_func=build_func,
            output=output,
        )
    )
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Model definition for the compiler."""
2-
from . import llama, llama_config, llama_parameter
2+
from .model import MODELS, Model

python/mlc_chat/compiler/model/llama.py renamed to python/mlc_chat/compiler/model/llama_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor
156156

157157

158158
class LlamaForCasualLM(nn.Module):
159-
def __init__(self, config: LlamaConfig, dtype: str = "float32"):
159+
def __init__(self, config: LlamaConfig):
160160
self.model = LlamaModel(config)
161161
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
162162
self.vocab_size = config.vocab_size
163-
self.dtype = dtype
163+
self.dtype = "float32"
164164

165165
def to(self, dtype: Optional[str] = None):
166166
super().to(dtype=dtype)

python/mlc_chat/compiler/model/llama_parameter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import numpy as np
88

99
from ..parameter import ExternMapping
10-
from .llama import LlamaConfig, LlamaForCasualLM
10+
from .llama_config import LlamaConfig
11+
from .llama_model import LlamaForCasualLM
1112

1213

13-
def hf_torch(model_config: LlamaConfig) -> ExternMapping:
14+
def huggingface(model_config: LlamaConfig, _) -> ExternMapping:
1415
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
1516
the names of HuggingFace PyTorch parameters.
1617
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""A centralized registry of all existing model architectures and their configurations."""
2+
import dataclasses
3+
from pathlib import Path
4+
from typing import Any, Callable, Dict, Optional
5+
6+
from tvm.relax.frontend import nn
7+
8+
from ..parameter import ExternMapping, QuantizeMapping
9+
from . import llama_config, llama_model, llama_parameter
10+
11+
ModelConfig = Any
12+
QuantizeConfig = Any
13+
14+
LoaderType = Callable[[ModelConfig, QuantizeConfig], ExternMapping]
15+
QuantizerType = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping]
16+
17+
18+
@dataclasses.dataclass
class Model:
    """Registry record for one model architecture.

    Groups the architecture's name, its module constructor, its configuration
    loader, and the optional parameter loaders / quantizers available for it.
    """

    # Key under which the architecture is registered.
    name: str
    # Builds the nn.Module from a model configuration.
    model: Callable[[ModelConfig], nn.Module]
    # Builds the model configuration from a path.
    config: Callable[[Path], ModelConfig]
    # Optional hooks; ``None`` means the hook is not provided for this architecture.
    source_loader_huggingface: Optional[LoaderType] = None
    source_loader_awq: Optional[LoaderType] = None
    quantizer_group_quant: Optional[QuantizerType] = None
28+
29+
30+
# Registry of supported architectures, keyed by each entry's own ``name``.
MODELS: Dict[str, Model] = {
    entry.name: entry
    for entry in (
        Model(
            name="llama",
            model=llama_model.LlamaForCasualLM,
            config=llama_config.LlamaConfig.from_file,
            source_loader_huggingface=llama_parameter.huggingface,
            source_loader_awq=None,
            quantizer_group_quant=None,
        ),
    )
}

python/mlc_chat/compiler/parameter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
A subpackage of the compiler that represents mapping between external parameters, quantized
33
parameters and parameters in MLC-defined models.
44
"""
5-
from .hf_loader import HFLoader
5+
from .huggingface_loader import HuggingFaceLoader
66
from .mapping import ExternMapping, QuantizeMapping

python/mlc_chat/compiler/parameter/hf_loader.py renamed to python/mlc_chat/compiler/parameter/huggingface_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22-
class HFLoader: # pylint: disable=too-few-public-methods
22+
class HuggingFaceLoader: # pylint: disable=too-few-public-methods
2323
"""A loader loading HuggingFace's PyTorch/SafeTensor format and converts them
2424
to MLC's parameters.
2525
@@ -161,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) ->
161161
return list(order.keys())
162162

163163

164-
__all__ = ["HFLoader"]
164+
__all__ = ["HuggingFaceLoader"]

python/mlc_chat/support/auto_config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import logging
33
from pathlib import Path
44

5+
from .style import green
6+
57
logger = logging.getLogger(__name__)
68

9+
FOUND = green("Found")
10+
711

812
def detect_config(config_path: Path) -> Path:
913
"""Detect and return the path that points to config.json. If config_path is a directory,
@@ -30,5 +34,5 @@ def detect_config(config_path: Path) -> Path:
3034
else:
3135
config_json_path = config_path
3236

33-
logger.info("Found config.json: %s", config_json_path)
37+
logger.info("%s model configuration: %s", FOUND, config_json_path)
3438
return config_json_path

0 commit comments

Comments
 (0)