Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mlc_chat/cli/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():

def _parse_config(path: Union[str, Path]) -> Path:
try:
return detect_config(Path(path))
return detect_config(path)
except ValueError as err:
raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")

Expand Down
9 changes: 3 additions & 6 deletions python/mlc_chat/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@
but users could optionally import it if they want to use the compiler.
"""
from . import compiler_pass
from .compile import ( # pylint: disable=redefined-builtin
CompileArgs,
OptimizationFlags,
compile,
)
from .model import MODELS, Model
from .compile import CompileArgs, compile # pylint: disable=redefined-builtin
from .flags_optimization import OptimizationFlags
from .model import MODEL_PRESETS, MODELS, Model
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
from .quantization import QUANT
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Model definition for the compiler."""
from .model import MODELS, Model
from .model import MODEL_PRESETS, MODELS, Model
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,5 @@ class Model:
quantize={},
)
}

MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG
30 changes: 24 additions & 6 deletions python/mlc_chat/support/auto_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Help function for detecting the model configuration file `config.json`"""
import json
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

from .style import green

Expand All @@ -14,25 +15,42 @@
FOUND = green("Found")


def detect_config(config_path: Path) -> Path:
"""Detect and return the path that points to config.json. If config_path is a directory,
def detect_config(config: Union[str, Path]) -> Path:
"""Detect and return the path that points to config.json. If `config` is a directory,
it looks for config.json below it.

Parameters
---------
config_path : pathlib.Path
The path to config.json or the directory containing config.json.
config : Union[str, pathlib.Path]
The preset name of the model, or the path to `config.json`, or the directory containing
`config.json`.

Returns
-------
config_json_path : pathlib.Path
The path points to config.json.
"""
from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel
MODEL_PRESETS,
)

if isinstance(config, str) and config in MODEL_PRESETS:
content = MODEL_PRESETS[config]
temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with
suffix=".json",
delete=False,
)
logger.info("%s preset model configuration: %s", FOUND, temp_file.name)
config_path = Path(temp_file.name)
with config_path.open("w", encoding="utf-8") as config_file:
json.dump(content, config_file, indent=2)
else:
config_path = Path(config)
if not config_path.exists():
raise ValueError(f"{config_path} does not exist.")

if config_path.is_dir():
# search config.json under config_path
# search config.json under config path
config_json_path = config_path / "config.json"
if not config_json_path.exists():
raise ValueError(f"Fail to find config.json under {config_path}.")
Expand Down