Skip to content

Commit 16dd2ae

Browse files
authored
[Slim-LM] Smart path finding for config and weight (#1088)
1 parent 6159cc4 commit 16dd2ae

File tree

4 files changed

+307
-0
lines changed

4 files changed

+307
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""Help function for detecting the model configuration file `config.json`"""
2+
import logging
3+
from pathlib import Path
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
def detect_config(config_path: Path) -> Path:
    """Resolve `config_path` to the location of a `config.json` file.

    A directory is searched for a direct child named `config.json`; any other
    existing path is returned unchanged.

    Parameters
    ----------
    config_path : pathlib.Path
        Either the path to config.json itself or the directory that contains it.

    Returns
    -------
    config_json_path : pathlib.Path
        The path that points to config.json.

    Raises
    ------
    ValueError
        If `config_path` does not exist, or if it is a directory that has no
        `config.json` directly under it.
    """
    if not config_path.exists():
        raise ValueError(f"{config_path} does not exist.")

    # Assume the caller handed us the file itself unless it turns out to be a directory.
    resolved = config_path
    if config_path.is_dir():
        resolved = config_path / "config.json"
        if not resolved.exists():
            raise ValueError(f"Fail to find config.json under {config_path}.")

    logger.info("Found config.json: %s", resolved)
    return resolved
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Help functions for detecting weight paths and weight formats."""
2+
import json
3+
import logging
4+
from pathlib import Path
5+
from typing import Tuple
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def detect_weight(
    weight_path: Path, config_json_path: Path, weight_format: str = "auto"
) -> Tuple[Path, str]:
    """Detect the weight directory, and detect the weight format.

    Parameters
    ----------
    weight_path : pathlib.Path
        The path to weight files. May be None. If it is not None, it must exist.
        If it is None, the path is taken from the "weight_path" entry of
        `config.json` when present, otherwise the directory containing
        `config.json` is used.

    config_json_path : pathlib.Path
        The path to `config.json`. Only consulted when `weight_path` is None.

    weight_format : str
        The hint for the weight format. If it is "auto", guess the weight format.
        Otherwise, check the weights are in that format.
        Available weight formats:
            - auto (guess the weight format)
            - PyTorch (validate via checking pytorch_model.bin.index.json)
            - SafeTensor (validate via checking model.safetensors.index.json)
            - AWQ
            - GGML/GGUF

    Returns
    -------
    weight_path : pathlib.Path
        The path that points to the weights.

    weight_format : str
        The valid weight format.

    Raises
    ------
    AssertionError
        If `weight_path` is None and `config_json_path` is None or does not exist.
    ValueError
        If the resolved weight path does not exist, the format cannot be guessed,
        the format is not listed in `AVAILABLE_WEIGHT_FORMAT`, or the format's
        check function rejects the weight path.
    """
    if weight_path is None:
        # Raised explicitly rather than through an `assert` statement so the check is
        # not stripped under `python -O`; the exception type is kept as AssertionError
        # for callers that already catch it.
        if config_json_path is None or not config_json_path.exists():
            raise AssertionError("Please provide config.json path.")

        # 1. Find the weight_path in config.json
        with open(config_json_path, encoding="utf-8") as i_f:
            config = json.load(i_f)
        if "weight_path" in config:
            weight_path = Path(config["weight_path"])
            logger.info('Found "weight_path" in config.json: %s', weight_path)
        else:
            # 2. Find the weights file in the same directory as config.json
            weight_path = config_json_path.parent

    # One existence check covers both an explicit path and one read from config.json;
    # the config.json parent directory trivially passes because config.json exists.
    if not weight_path.exists():
        raise ValueError(f"weight_path doesn't exist: {weight_path}")

    logger.info("Loading weights from directory: %s", weight_path)

    # weight_format = "auto": guess the weight format.
    # Otherwise: check that the requested weight format is valid.
    if weight_format == "auto":
        weight_format = _guess_weight_format(weight_path)

    if weight_format not in AVAILABLE_WEIGHT_FORMAT:
        raise ValueError(
            f"Available weight format list: {AVAILABLE_WEIGHT_FORMAT}, but got {weight_format}"
        )

    # Only formats registered in CHECK_FORMAT_METHODS can be verified on disk;
    # the remaining formats are accepted as given.
    check_func = CHECK_FORMAT_METHODS.get(weight_format)
    if check_func is not None and not check_func(weight_path):
        raise ValueError(f"The weight is not in {weight_format} format.")
    return weight_path, weight_format
79+
80+
81+
def _guess_weight_format(weight_path: Path) -> str:
    """Guess the weight format by running every registered format check.

    Parameters
    ----------
    weight_path : pathlib.Path
        The path that is handed to each check function in `CHECK_FORMAT_METHODS`.

    Returns
    -------
    weight_format : str
        The first format, in `CHECK_FORMAT_METHODS` iteration order, whose check passes.

    Raises
    ------
    ValueError
        If no check function accepts `weight_path`.
    """
    # Every check is run (not just up to the first hit), so each one logs its outcome.
    possible_formats = [
        weight_format
        for weight_format, check_func in CHECK_FORMAT_METHODS.items()
        if check_func(weight_path)
    ]

    if not possible_formats:
        raise ValueError(
            "Fail to detect weight format. Use `--weight-format` to manually specify the format."
        )

    selected_format = possible_formats[0]
    # Use the module logger, like the rest of this file, instead of the root logger.
    logger.info(
        "Using %s format now. Use `--weight-format` to manually specify the format.",
        selected_format,
    )
    return selected_format
98+
99+
100+
def _check_pytorch(weight_path: Path):
    """Report whether `pytorch_model.bin.index.json` exists directly under `weight_path`.

    The outcome is logged at INFO level and returned as a bool.
    """
    index_file = weight_path / "pytorch_model.bin.index.json"
    if index_file.exists():
        logger.info("[Y] Found Huggingface PyTorch: %s", index_file)
        return True
    logger.info("[X] Not found: Huggingface PyTorch")
    return False
108+
109+
110+
def _check_safetensor(weight_path: Path):
    """Report whether `model.safetensors.index.json` exists directly under `weight_path`.

    The outcome is logged at INFO level and returned as a bool.
    """
    index_file = weight_path / "model.safetensors.index.json"
    if index_file.exists():
        logger.info("[Y] Found SafeTensor: %s", index_file)
        return True
    logger.info("[X] Not found: SafeTensor")
    return False
118+
119+
120+
# Formats that can be verified on disk, mapped to their check function.
# Iteration order matters: when guessing, the first format whose check passes is selected.
CHECK_FORMAT_METHODS = {
    "PyTorch": _check_pytorch,
    "SafeTensor": _check_safetensor,
}

# Every accepted format name: the verifiable formats above, followed by the formats
# that have no check function.
AVAILABLE_WEIGHT_FORMAT = [*CHECK_FORMAT_METHODS, "GGML", "GGUF", "AWQ"]

tests/python/test_auto_config.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# pylint: disable=missing-docstring
2+
import json
3+
import logging
4+
import tempfile
5+
from pathlib import Path
6+
7+
import pytest
8+
from mlc_chat.support.auto_config import detect_config
9+
10+
# Configure the root logger at INFO level with a "{"-style format that carries the
# timestamp, level, source file and line number of each record.
logging.basicConfig(
    format="{asctime} {levelname} {filename}:{lineno}: {message}",
    datefmt="%Y-%m-%d %H:%M:%S",
    style="{",
    level=logging.INFO,
)
16+
17+
18+
def _create_json_file(json_path, data):
    """Write `data` as JSON to `json_path`, creating or overwriting the file."""
    with open(json_path, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file)
21+
22+
23+
def test_detect_config():
    """Both the containing directory and the file itself resolve to config.json."""
    with tempfile.TemporaryDirectory() as tmpdir:
        expected = Path(tmpdir) / "config.json"
        _create_json_file(expected, {})

        # First the directory that holds config.json, then the file path itself.
        for candidate in (expected.parent, expected):
            assert detect_config(candidate) == expected
31+
32+
33+
def test_detect_config_fail():
    """`detect_config` raises ValueError for a missing path and for a directory
    that contains no config.json."""
    with pytest.raises(ValueError):
        detect_config(Path("do/not/exist"))

    with tempfile.TemporaryDirectory() as tmpdir:
        base_path = Path(tmpdir)
        with pytest.raises(ValueError):
            # No `assert` around the call: the expectation is the exception itself.
            detect_config(base_path)
41+
42+
43+
if __name__ == "__main__":
    # Run this file's tests when it is executed directly, and propagate
    # pytest's exit status to the shell.
    raise SystemExit(pytest.main([__file__]))

tests/python/test_auto_weight.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# pylint: disable=missing-docstring
2+
import json
3+
import logging
4+
import os
5+
import tempfile
6+
from pathlib import Path
7+
8+
import pytest
9+
from mlc_chat.support.auto_weight import detect_weight
10+
11+
# Configure the root logger at INFO level with a "{"-style format that carries the
# timestamp, level, source file and line number of each record.
logging.basicConfig(
    format="{asctime} {levelname} {filename}:{lineno}: {message}",
    datefmt="%Y-%m-%d %H:%M:%S",
    style="{",
    level=logging.INFO,
)
17+
18+
19+
def _create_json_file(json_path, data):
    """Write `data` as JSON to `json_path`, creating or overwriting the file."""
    with open(json_path, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file)
22+
23+
24+
@pytest.mark.parametrize(
    "weight_format, index_filename, result",
    [
        ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"),
        ("SafeTensor", "model.safetensors.index.json", "SafeTensor"),
        ("GGML", None, "GGML"),
        ("GGUF", None, "GGUF"),
        ("AWQ", None, "AWQ"),
        ("auto", "pytorch_model.bin.index.json", "PyTorch"),
        ("auto", "model.safetensors.index.json", "SafeTensor"),
    ],
)
def test_detect_weight(weight_format, index_filename, result):
    """An explicitly given weight directory is returned along with the expected format."""
    with tempfile.TemporaryDirectory() as tmpdir:
        weight_dir = Path(tmpdir)
        # Formats without an index file are exercised against an empty directory.
        if index_filename is not None:
            _create_json_file(weight_dir / index_filename, {})

        detected = detect_weight(weight_dir, None, weight_format)
        assert detected == (weight_dir, result)
43+
44+
45+
@pytest.mark.parametrize(
    "weight_format, index_filename, result",
    [
        ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"),
        ("SafeTensor", "model.safetensors.index.json", "SafeTensor"),
        ("GGML", None, "GGML"),
        ("GGUF", None, "GGUF"),
        ("AWQ", None, "AWQ"),
        ("auto", "pytorch_model.bin.index.json", "PyTorch"),
        ("auto", "model.safetensors.index.json", "SafeTensor"),
    ],
)
def test_detect_weight_in_config_json(weight_format, index_filename, result):
    """With no explicit weight path, the "weight_path" entry of config.json is used."""
    with tempfile.TemporaryDirectory() as config_dir, tempfile.TemporaryDirectory() as weight_dir:
        expected_weight_path = Path(weight_dir)
        config_json_path = Path(config_dir) / "config.json"
        # config.json lives in one directory and points at the other one.
        _create_json_file(config_json_path, {"weight_path": weight_dir})
        if index_filename is not None:
            _create_json_file(expected_weight_path / index_filename, {})

        detected = detect_weight(None, config_json_path, weight_format)
        assert detected == (expected_weight_path, result)
68+
69+
70+
@pytest.mark.parametrize(
    "weight_format, index_filename, result",
    [
        ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"),
        ("SafeTensor", "model.safetensors.index.json", "SafeTensor"),
        ("GGML", None, "GGML"),
        ("GGUF", None, "GGUF"),
        ("AWQ", None, "AWQ"),
        ("auto", "pytorch_model.bin.index.json", "PyTorch"),
        ("auto", "model.safetensors.index.json", "SafeTensor"),
    ],
)
def test_detect_weight_same_dir_config_json(weight_format, index_filename, result):
    """With no explicit weight path and no "weight_path" entry in config.json,
    the directory containing config.json is used."""
    with tempfile.TemporaryDirectory() as tmpdir:
        model_dir = Path(tmpdir)
        config_json_path = model_dir / "config.json"
        _create_json_file(config_json_path, {})
        if index_filename is not None:
            _create_json_file(os.path.join(tmpdir, index_filename), {})

        detected = detect_weight(None, config_json_path, weight_format)
        assert detected == (model_dir, result)
91+
92+
93+
def test_find_weight_fail():
    """A missing weight path raises ValueError; a missing config.json (with no
    weight path given) raises AssertionError."""
    missing = Path("do/not/exist")

    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(ValueError):
            detect_weight(missing, Path(tmpdir), "AWQ")

    # This case does not touch the temporary directory at all.
    with pytest.raises(AssertionError):
        detect_weight(None, missing, "AWQ")
101+
102+
103+
if __name__ == "__main__":
    # Run this file's tests when it is executed directly, and propagate
    # pytest's exit status to the shell.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)