|
1 | 1 | """Help function for detecting the model configuration file `config.json`"""
|
2 | 2 | import json
|
3 | 3 | import logging
|
| 4 | +import tempfile |
4 | 5 | from pathlib import Path
|
5 |
| -from typing import TYPE_CHECKING |
| 6 | +from typing import TYPE_CHECKING, Union |
6 | 7 |
|
7 | 8 | from .style import green
|
8 | 9 |
|
|
14 | 15 | FOUND = green("Found")
|
15 | 16 |
|
16 | 17 |
|
17 |
| -def detect_config(config_path: Path) -> Path: |
18 |
| - """Detect and return the path that points to config.json. If config_path is a directory, |
| 18 | +def detect_config(config: Union[str, Path]) -> Path: |
| 19 | + """Detect and return the path that points to config.json. If `config` is a directory, |
19 | 20 | it looks for config.json below it.
|
20 | 21 |
|
21 | 22 | Parameters
|
22 | 23 | ---------
|
23 |
| - config_path : pathlib.Path |
24 |
| - The path to config.json or the directory containing config.json. |
| 24 | + config : Union[str, pathlib.Path] |
| 25 | + The preset name of the model, or the path to `config.json`, or the directory containing |
| 26 | + `config.json`. |
25 | 27 |
|
26 | 28 | Returns
|
27 | 29 | -------
|
28 | 30 | config_json_path : pathlib.Path
|
29 | 31 | The path points to config.json.
|
30 | 32 | """
|
| 33 | + from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel |
| 34 | + MODEL_PRESETS, |
| 35 | + ) |
| 36 | + |
| 37 | + if isinstance(config, str) and config in MODEL_PRESETS: |
| 38 | + content = MODEL_PRESETS[config] |
| 39 | + temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with |
| 40 | + suffix=".json", |
| 41 | + delete=False, |
| 42 | + ) |
| 43 | + logger.info("%s preset model configuration: %s", FOUND, temp_file.name) |
| 44 | + config_path = Path(temp_file.name) |
| 45 | + with config_path.open("w", encoding="utf-8") as config_file: |
| 46 | + json.dump(content, config_file, indent=2) |
| 47 | + else: |
| 48 | + config_path = Path(config) |
31 | 49 | if not config_path.exists():
|
32 | 50 | raise ValueError(f"{config_path} does not exist.")
|
33 | 51 |
|
34 | 52 | if config_path.is_dir():
|
35 |
| - # search config.json under config_path |
| 53 | + # search config.json under config path |
36 | 54 | config_json_path = config_path / "config.json"
|
37 | 55 | if not config_json_path.exists():
|
38 | 56 | raise ValueError(f"Fail to find config.json under {config_path}.")
|
|
0 commit comments