1
1
"""A centralized registry of all existing model architectures and their configurations."""
2
2
import dataclasses
3
- from typing import Any , Callable , Dict
3
+ from typing import Any , Callable , Dict , Tuple
4
4
5
5
from tvm .relax .frontend import nn
6
6
7
7
from ..parameter import ExternMapping , QuantizeMapping
8
- from ..quantization .quantization import QuantizeConfig
9
- from . import llama_config , llama_model , llama_parameter
8
+ from ..quantization .quantization import Quantization
9
+ from . import llama_config , llama_model , llama_parameter , llama_quantization
10
10
11
11
ModelConfig = Any
12
12
"""A ModelConfig is an object that represents a model architecture. It is required to have
@@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig:
16
16
...
17
17
"""
18
18
19
- FuncGetExternMap = Callable [[ModelConfig , QuantizeConfig ], ExternMapping ]
20
- FuncGetQuantMap = Callable [[ModelConfig , QuantizeConfig ], QuantizeMapping ]
19
+ FuncGetExternMap = Callable [[ModelConfig , Quantization ], ExternMapping ]
20
+ FuncQuantization = Callable [[ModelConfig , Quantization ], Tuple [ nn . Module , QuantizeMapping ] ]
21
21
22
22
23
23
@dataclasses .dataclass
@@ -38,15 +38,16 @@ class Model:
38
38
source : Dict[str, FuncGetExternMap]
39
39
A dictionary that maps the name of a source format to parameter mapping.
40
40
41
- quantize: Dict[str, FuncGetQuantMap]
42
- A dictionary that maps the name of a quantization method to quantization mapping.
41
+ quantize: Dict[str, FuncQuantization]
42
A dictionary that maps the name of a quantization method to a function returning the quantized model and the
43
+ quantization parameter mapping.
43
44
"""
44
45
45
46
name : str
46
47
config : ModelConfig
47
48
model : Callable [[ModelConfig ], nn .Module ]
48
49
source : Dict [str , FuncGetExternMap ]
49
- quantize : Dict [str , FuncGetQuantMap ]
50
+ quantize : Dict [str , FuncQuantization ]
50
51
51
52
52
53
MODELS : Dict [str , Model ] = {
@@ -58,7 +59,9 @@ class Model:
58
59
"huggingface-torch" : llama_parameter .huggingface ,
59
60
"huggingface-safetensor" : llama_parameter .huggingface ,
60
61
},
61
- quantize = {},
62
+ quantize = {
63
+ "group-quant" : llama_quantization .group_quant ,
64
+ },
62
65
)
63
66
}
64
67
0 commit comments