Skip to content

Commit a8c1d66

Browse files
committed
Compile Model Preset without External config.json
This PR adds support for compiling a preset of models without having to provide a `config.json` on disk using the commands below: ```diff python -m mlc_chat.cli.compile \ --quantization q4f16_1 -o /tmp/1.so \ - --config /models/Llama-2-7b-chat-hf + --config llama2_7b ``` This allows easier testing and binary distribution without having to depend on external model directory.
1 parent 0a25374 commit a8c1d66

File tree

5 files changed

+31
-14
lines changed

5 files changed

+31
-14
lines changed

python/mlc_chat/cli/compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def main():
2727

2828
def _parse_config(path: Union[str, Path]) -> Path:
2929
try:
30-
return detect_config(Path(path))
30+
return detect_config(path)
3131
except ValueError as err:
3232
raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
3333

python/mlc_chat/compiler/__init__.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
but users could optionally import it if they want to use the compiler.
44
"""
55
from . import compiler_pass
6-
from .compile import ( # pylint: disable=redefined-builtin
7-
CompileArgs,
8-
OptimizationFlags,
9-
compile,
10-
)
11-
from .model import MODELS, Model
6+
from .compile import CompileArgs, compile # pylint: disable=redefined-builtin
7+
from .flags_optimization import OptimizationFlags
8+
from .model import MODEL_PRESETS, MODELS, Model
129
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
1310
from .quantization import QUANT
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Model definition for the compiler."""
2-
from .model import MODELS, Model
2+
from .model import MODEL_PRESETS, MODELS, Model

python/mlc_chat/compiler/model/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,5 @@ class Model:
6161
quantize={},
6262
)
6363
}
64+
65+
MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG

python/mlc_chat/support/auto_config.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Help function for detecting the model configuration file `config.json`"""
22
import json
33
import logging
4+
import tempfile
45
from pathlib import Path
5-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, Union
67

78
from .style import green
89

@@ -14,25 +15,42 @@
1415
FOUND = green("Found")
1516

1617

17-
def detect_config(config_path: Path) -> Path:
18-
"""Detect and return the path that points to config.json. If config_path is a directory,
18+
def detect_config(config: Union[str, Path]) -> Path:
19+
"""Detect and return the path that points to config.json. If `config` is a directory,
1920
it looks for config.json below it.
2021
2122
Parameters
2223
---------
23-
config_path : pathlib.Path
24-
The path to config.json or the directory containing config.json.
24+
config : Union[str, pathlib.Path]
25+
The preset name of the model, or the path to `config.json`, or the directory containing
26+
`config.json`.
2527
2628
Returns
2729
-------
2830
config_json_path : pathlib.Path
2931
The path points to config.json.
3032
"""
33+
from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel
34+
MODEL_PRESETS,
35+
)
36+
37+
if isinstance(config, str) and config in MODEL_PRESETS:
38+
content = MODEL_PRESETS[config]
39+
temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with
40+
suffix=".json",
41+
delete=False,
42+
)
43+
logger.info("%s preset model configuration: %s", FOUND, temp_file.name)
44+
config_path = Path(temp_file.name)
45+
with config_path.open("w", encoding="utf-8") as config_file:
46+
json.dump(content, config_file, indent=2)
47+
else:
48+
config_path = Path(config)
3149
if not config_path.exists():
3250
raise ValueError(f"{config_path} does not exist.")
3351

3452
if config_path.is_dir():
35-
# search config.json under config_path
53+
# search config.json under config path
3654
config_json_path = config_path / "config.json"
3755
if not config_json_path.exists():
3856
raise ValueError(f"Fail to find config.json under {config_path}.")

0 commit comments

Comments
 (0)