Skip to content

Commit 3d88fea

Browse files
committed
Add pylint/mypy tooling into pyproject.toml
This PR establishes the initial Python tooling infra with Pylint and Mypy. Currently only the newest modules, i.e. `mlc_chat.support` and `mlc_chat.compiler` are covered, and we expect to cover the entire package, as being tracked in #1101.
1 parent 03c641a commit 3d88fea

File tree

8 files changed

+94
-19
lines changed

8 files changed

+94
-19
lines changed

.github/workflows/python_lint.yml

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
name: Python Lint
2-
32
on: [push, pull_request]
4-
53
env:
6-
IMAGE: 'mlcaidev/ci-cpu:8a87699'
4+
IMAGE: 'mlcaidev/ci-cpu:2c03e7f'
75

86
jobs:
97
isort:
@@ -35,3 +33,33 @@ jobs:
3533
- name: Lint
3634
run: |
3735
./ci/bash.sh $IMAGE bash ./ci/task/black.sh
36+
37+
mypy:
38+
runs-on: ubuntu-latest
39+
steps:
40+
- uses: actions/checkout@v3
41+
with:
42+
submodules: 'recursive'
43+
- name: Version
44+
run: |
45+
wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
46+
chmod u+x ./ci/bash.sh
47+
./ci/bash.sh $IMAGE "conda env export --name ci-lint"
48+
- name: Lint
49+
run: |
50+
./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh
51+
52+
pylint:
53+
runs-on: ubuntu-latest
54+
steps:
55+
- uses: actions/checkout@v3
56+
with:
57+
submodules: 'recursive'
58+
- name: Version
59+
run: |
60+
wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh
61+
chmod u+x ./ci/bash.sh
62+
./ci/bash.sh $IMAGE "conda env export --name ci-lint"
63+
- name: Lint
64+
run: |
65+
./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh

ci/task/mypy.sh

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/bin/bash
2+
set -eo pipefail
3+
4+
source ~/.bashrc
5+
micromamba activate ci-lint
6+
NUM_THREADS=$(nproc)
7+
8+
mypy ./python/mlc_chat/compiler ./python/mlc_chat/support
9+
mypy ./tests/python/model ./tests/python/parameter

ci/task/pylint.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
2+
set -eo pipefail
3+
4+
source ~/.bashrc
5+
micromamba activate ci-lint
6+
NUM_THREADS=$(nproc)
7+
8+
# TVM Unity is a dependency to this testing
9+
pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
10+
11+
pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support
12+
pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter

pyproject.toml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,19 @@ profile = "black"
1919

2020
[tool.black]
2121
line-length = 100
22-
target-version = ['py310']
22+
23+
[tool.mypy]
24+
ignore_missing_imports = true
25+
show_column_numbers = true
26+
show_error_context = true
27+
follow_imports = "skip"
28+
ignore_errors = false
29+
strict_optional = false
30+
install_types = true
31+
non_interactive = true
32+
33+
[tool.pylint.messages_control]
34+
max-line-length = 100
35+
disable = """
36+
duplicate-code,
37+
"""

python/mlc_chat/compiler/model/llama_parameter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
33
PyTorch, HuggingFace safetensors.
44
"""
5+
from typing import Callable, Dict, List
6+
57
import numpy as np
68

79
from ..parameter import ExternMapping
@@ -26,33 +28,33 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping:
2628
_, named_params = model.export_tvm(spec=model.get_default_spec())
2729
parameter_names = {name for name, _ in named_params}
2830

29-
param_map = {}
30-
map_func = {}
31+
param_map: Dict[str, List[str]] = {}
32+
map_func: Dict[str, Callable] = {}
3133
unused_params = set()
3234

3335
for i in range(model_config.num_hidden_layers):
3436
# Add QKV in self attention
3537
attn = f"model.layers.{i}.self_attn"
3638
assert f"{attn}.qkv_proj.weight" in parameter_names
3739
map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0)
38-
param_map[f"{attn}.qkv_proj.weight"] = (
40+
param_map[f"{attn}.qkv_proj.weight"] = [
3941
f"{attn}.q_proj.weight",
4042
f"{attn}.k_proj.weight",
4143
f"{attn}.v_proj.weight",
42-
)
44+
]
4345
# Add gates in MLP
4446
mlp = f"model.layers.{i}.mlp"
4547
assert f"{mlp}.gate_up_proj.weight" in parameter_names
4648
map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0)
47-
param_map[f"{mlp}.gate_up_proj.weight"] = (
49+
param_map[f"{mlp}.gate_up_proj.weight"] = [
4850
f"{mlp}.gate_proj.weight",
4951
f"{mlp}.up_proj.weight",
50-
)
52+
]
5153
# inv_freq is not used in the model
5254
unused_params.add(f"{attn}.rotary_emb.inv_freq")
5355

5456
for name in parameter_names:
5557
if name not in map_func:
5658
map_func[name] = lambda x: x
57-
param_map[name] = (name,)
59+
param_map[name] = [name]
5860
return ExternMapping(param_map, map_func, unused_params)

python/mlc_chat/compiler/parameter/mapping.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
"""Parameter mapping for converting different LLM implementations to MLC LLM."""
22
import dataclasses
3-
from typing import Callable, Dict, List, Set
3+
from typing import Callable, Dict, List, Set, Union
44

55
import numpy as np
66
from tvm.runtime import NDArray
77

8+
MapFuncVariadic = Union[
9+
Callable[[], np.ndarray],
10+
Callable[[np.ndarray], np.ndarray],
11+
Callable[[np.ndarray, np.ndarray], np.ndarray],
12+
Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
13+
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray],
14+
]
15+
816

917
@dataclasses.dataclass
1018
class ExternMapping:
@@ -33,8 +41,8 @@ class ExternMapping:
3341
"""
3442

3543
param_map: Dict[str, List[str]]
36-
map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]]
37-
unused_params: Set[str] = dataclasses.field(default_factory=dict)
44+
map_func: Dict[str, MapFuncVariadic]
45+
unused_params: Set[str] = dataclasses.field(default_factory=set)
3846

3947

4048
@dataclasses.dataclass
@@ -72,8 +80,8 @@ class QuantizeMapping:
7280
used to convert the quantized parameters into the desired form.
7381
"""
7482

75-
param_map: Dict[str, Callable[str, List[str]]]
76-
map_func: Dict[str, Callable[NDArray, List[NDArray]]]
83+
param_map: Dict[str, Callable[[str], List[str]]]
84+
map_func: Dict[str, Callable[[NDArray], List[NDArray]]]
7785

7886

7987
__all__ = ["ExternMapping", "QuantizeMapping"]

python/mlc_chat/support/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:
3737
cfg : ConfigClass
3838
An instance of the config object.
3939
"""
40-
field_names = [field.name for field in dataclasses.fields(cls)]
40+
field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type]
4141
fields = {k: v for k, v in source.items() if k in field_names}
4242
kwargs = {k: v for k, v in source.items() if k not in field_names}
43-
return cls(**fields, kwargs=kwargs)
43+
return cls(**fields, kwargs=kwargs) # type: ignore[call-arg]
4444

4545
@classmethod
4646
def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass:

tests/python/parameter/test_hf_torch_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# pylint: disable=missing-docstring
22
import logging
33
from pathlib import Path
4+
from typing import Union
45

56
import pytest
67
from mlc_chat.compiler.model.llama import LlamaConfig
@@ -24,7 +25,7 @@
2425
"./dist/models/Llama-2-70b-hf",
2526
],
2627
)
27-
def test_load_llama(base_path: str):
28+
def test_load_llama(base_path: Union[str, Path]):
2829
base_path = Path(base_path)
2930
path_config = base_path / "config.json"
3031
path_params = base_path / "pytorch_model.bin.index.json"

0 commit comments

Comments
 (0)