18
18
from __future__ import annotations
19
19
20
20
from enum import Enum
21
- from typing import Callable , Dict , List , NamedTuple , Union
21
+ from typing import Callable , Dict , List , NamedTuple , Optional , Union
22
22
23
23
import torch
24
24
25
25
from neural_compressor .common .base_config import BaseConfig , register_config , registered_configs
26
- from neural_compressor .common .utility import DUMMY_CONFIG , GPTQ , RTN_WEIGHT_ONLY_QUANT
26
+ from neural_compressor .common .utility import DEFAULT_WHITE_LIST , GPTQ , OP_NAME_OR_MODULE_TYPE , RTN_WEIGHT_ONLY_QUANT
27
27
28
28
FRAMEWORK_NAME = "torch"
29
29
@@ -87,6 +87,7 @@ def __init__(
87
87
double_quant_bits : int = 8 ,
88
88
double_quant_sym : bool = True ,
89
89
double_quant_group_size : int = 256 ,
90
+ white_list : Optional [List [OP_NAME_OR_MODULE_TYPE ]] = DEFAULT_WHITE_LIST ,
90
91
):
91
92
"""Init RTN weight-only quantization config.
92
93
@@ -105,7 +106,7 @@ def __init__(
105
106
double_quant_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
106
107
double_quant_group_size (int): Size of double_quant groups, default is 32.
107
108
"""
108
- super ().__init__ ()
109
+ super ().__init__ (white_list = white_list )
109
110
self .weight_bits = weight_bits
110
111
self .weight_dtype = weight_dtype
111
112
self .weight_group_size = weight_group_size
@@ -119,6 +120,7 @@ def __init__(
119
120
self .double_quant_dtype = double_quant_dtype
120
121
self .double_quant_sym = double_quant_sym
121
122
self .double_quant_group_size = double_quant_group_size
123
+ self ._post_init ()
122
124
123
125
def to_dict (self ):
124
126
return super ().to_dict (params_list = self .params_list , operator2str = operator2str )
@@ -220,12 +222,13 @@ def __init__(
220
222
double_quant_bits : int = 8 ,
221
223
double_quant_sym : bool = True ,
222
224
double_quant_group_size : int = 256 ,
225
+ white_list : Optional [List [OP_NAME_OR_MODULE_TYPE ]] = DEFAULT_WHITE_LIST ,
223
226
):
224
227
"""Init GPTQ config.
225
228
226
229
Args:
227
230
"""
228
- super ().__init__ ()
231
+ super ().__init__ (white_list = white_list )
229
232
self .weight_dtype = weight_dtype
230
233
self .weight_bits = weight_bits
231
234
self .weight_group_size = weight_group_size
@@ -248,6 +251,7 @@ def __init__(
248
251
self .double_quant_dtype = double_quant_dtype
249
252
self .double_quant_sym = double_quant_sym
250
253
self .double_quant_group_size = double_quant_group_size
254
+ self ._post_init ()
251
255
252
256
def to_dict (self ):
253
257
return super ().to_dict (params_list = self .params_list , operator2str = operator2str )
@@ -281,65 +285,6 @@ def get_default_gptq_config() -> GPTQConfig:
281
285
return GPTQConfig ()
282
286
283
287
284
- ######################## Dummy Config ###############################
285
- # TODO (Yi) remove it after finishing the GPTQ config
286
- @register_config (framework_name = FRAMEWORK_NAME , algo_name = DUMMY_CONFIG )
287
- class DummyConfig (BaseConfig ):
288
- """Config class for round-to-nearest weight-only quantization."""
289
-
290
- supported_configs : List [OperatorConfig ] = []
291
- params_list = ["act_dtype" , "weight_dtype" , "dummy_attr" ]
292
- name = DUMMY_CONFIG
293
-
294
- def __init__ (
295
- self ,
296
- weight_dtype : str = "int" ,
297
- act_dtype : str = "fp32" ,
298
- dummy_attr : int = 0 ,
299
- ):
300
- """Init RTN weight-only quantization config.
301
-
302
- Args:
303
- act_dtype (str): Data type for activations, default is "fp32".
304
- weight_dtype (str): Data type for weights, default is "int".
305
- dummy_attr (int): Dummy attribute, default is 0.
306
- """
307
- super ().__init__ ()
308
- self .act_dtype = act_dtype
309
- self .weight_dtype = weight_dtype
310
- self .dummy_attr = dummy_attr
311
-
312
- def to_dict (self ):
313
- return super ().to_dict (params_list = self .params_list , operator2str = operator2str )
314
-
315
- @classmethod
316
- def from_dict (cls , config_dict ):
317
- return super (DummyConfig , cls ).from_dict (config_dict = config_dict , str2operator = str2operator )
318
-
319
- @classmethod
320
- def register_supported_configs (cls ) -> List [OperatorConfig ]:
321
- supported_configs = []
322
- linear_dummy_config = DummyConfig (
323
- act_dtype = ["fp32" ],
324
- weight_dtype = ["int4" , "int8" ],
325
- dummy_attr = [1 , 2 , 3 ],
326
- )
327
- operators = [torch .nn .Linear , torch .nn .functional .linear ]
328
- supported_configs .append (
329
- OperatorConfig (config = linear_dummy_config , operators = operators , backend = Backend .DEFAULT )
330
- )
331
- cls .supported_configs = supported_configs
332
-
333
-
334
- def get_default_dummy_config () -> DummyConfig :
335
- """Generate the default dummy config.
336
-
337
- Returns:
338
- the default dummy config.
339
- """
340
- return DummyConfig ()
341
-
342
-
343
288
##################### Algo Configs End ###################################
344
289
345
290
0 commit comments