import torch

from neural_compressor.common.base_config import BaseConfig, config_registry, register_config
- from neural_compressor.common.utility import (
-     DEFAULT_WHITE_LIST,
-     FP8_QUANT,
-     GPTQ,
-     OP_NAME_OR_MODULE_TYPE,
-     RTN_WEIGHT_ONLY_QUANT,
- )
+ from neural_compressor.common.utility import DEFAULT_WHITE_LIST, FP8_QUANT, GPTQ, OP_NAME_OR_MODULE_TYPE, RTN

from neural_compressor.torch.utils.constants import PRIORITY_GPTQ, PRIORITY_RTN
from neural_compressor.torch.utils.utility import is_hpex_avaliable, logger

@@ -60,8 +54,8 @@ class OperatorConfig(NamedTuple):
######################## RNT Config ###############################


- @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN_WEIGHT_ONLY_QUANT, priority=PRIORITY_RTN)
- class RTNWeightQuantConfig(BaseConfig):
+ @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
+ class RTNConfig(BaseConfig):
    """Config class for round-to-nearest weight-only quantization."""

    supported_configs: List[OperatorConfig] = []
@@ -80,7 +74,7 @@ class RTNWeightQuantConfig(BaseConfig):
        "double_quant_sym",
        "double_quant_group_size",
    ]
-     name = RTN_WEIGHT_ONLY_QUANT
+     name = RTN

    def __init__(
        self,
@@ -137,12 +131,12 @@ def to_dict(self):

    @classmethod
    def from_dict(cls, config_dict):
-         return super(RTNWeightQuantConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)
+         return super(RTNConfig, cls).from_dict(config_dict=config_dict, str2operator=str2operator)

    @classmethod
    def register_supported_configs(cls) -> List[OperatorConfig]:
        supported_configs = []
-         linear_rtn_config = RTNWeightQuantConfig(
+         linear_rtn_config = RTNConfig(
            weight_dtype=["int", "int8", "int4", "nf4", "fp4", "fp4_e2m1_bnb", "fp4_e2m1"],
            weight_bits=[4, 1, 2, 3, 5, 6, 7, 8],
            weight_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024],
@@ -173,16 +167,16 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:


# TODO(Yi) run `register_supported_configs` for all registered config.
- RTNWeightQuantConfig.register_supported_configs()
+ RTNConfig.register_supported_configs()


- def get_default_rtn_config() -> RTNWeightQuantConfig:
+ def get_default_rtn_config() -> RTNConfig:
    """Generate the default rtn config.

    Returns:
        the default rtn config.
    """
-     return RTNWeightQuantConfig()
+     return RTNConfig()


######################## GPTQ Config ###############################
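For context, a minimal usage sketch of the renamed class follows. The import path and the single-value constructor arguments are assumptions inferred from the imports and the params_list/register_supported_configs entries above; only RTNConfig and get_default_rtn_config themselves appear in this diff.

# Usage sketch -- import path assumed, not shown in this diff.
from neural_compressor.torch.quantization.config import RTNConfig, get_default_rtn_config

# Default round-to-nearest weight-only config (previously RTNWeightQuantConfig).
default_cfg = get_default_rtn_config()

# A customized config; field names follow the params_list entries shown above,
# with single values picked from the supported lists (illustrative only).
int4_cfg = RTNConfig(weight_dtype="int4", weight_bits=4, weight_group_size=32)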