@@ -72,9 +72,26 @@ class OperatorConfig(NamedTuple):
     valid_func_list: List[Callable] = []


+class TorchBaseConfig(BaseConfig):
+    # Override _get_op_name_op_type_config so op_type matching can fall back to plain strings,
+    # because the IPEX backend uses special fused op_types such as `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable op_type to its string name.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
+
+
 ######################## RTN Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
     """Config class for round-to-nearest weight-only quantization."""

     name = RTN
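The effect of the override is easiest to see in isolation. Below is a minimal, self-contained sketch; split_config and the callable check are illustrative stand-ins for the real BaseConfig internals (self._is_op_type, self._op_type_to_str), not the library's API, but the routing logic mirrors the hunk above:

# Sketch only: IPEX fused op_types such as "Linear&Relu" arrive as plain strings
# rather than Callables, so they must also be matched as op_types, not just op names.
from typing import Callable, Dict, Tuple, Union

import torch

def split_config(local_config: Dict[Union[str, Callable], str]) -> Tuple[dict, dict]:
    op_type_config_dict, op_name_config_dict = {}, {}
    for name, config in local_config.items():
        if callable(name):  # stand-in for self._is_op_type(name)
            op_type_config_dict[name.__name__] = config  # stand-in for self._op_type_to_str(name)
        else:
            op_name_config_dict[name] = config
            op_type_config_dict[name] = config  # the string fallback added in this diff
    return op_type_config_dict, op_name_config_dict

op_types, op_names = split_config({torch.nn.Linear: "4-bit", "Linear&Relu": "4-bit"})
assert op_types == {"Linear": "4-bit", "Linear&Relu": "4-bit"}
assert op_names == {"Linear&Relu": "4-bit"}

Without the fallback line, a string key like "Linear&Relu" would only ever match by operator name, so IPEX's fused op_types could never be targeted as types.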
@@ -242,7 +259,7 @@ def get_default_double_quant_config(type="BNB_NF4"):

 ######################## GPTQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
     """Config class for GPTQ.

     GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -397,7 +414,7 @@ def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.Proc

 ######################## AWQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
     """Config class for AWQ.

     AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -539,7 +556,7 @@ def get_default_awq_config() -> AWQConfig:

 ######################## TEQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
     """Config class for TEQ.

     TEQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -677,7 +694,7 @@ def get_default_teq_config() -> TEQConfig:

 ######################## AUTOROUND Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
     """Config class for AUTOROUND.

     AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -827,7 +844,7 @@ def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils

 ######################## MX Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
     """Config class for MX quantization."""

     supported_configs: List[OperatorConfig] = []
@@ -940,7 +957,7 @@ def get_default_mx_config() -> MXQuantConfig:

 ######################## Dynamic Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
     """Config class for dynamic quantization."""

     name = PT2E_DYNAMIC_QUANT
@@ -1014,7 +1031,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:

 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
     """Config class for static quantization."""

     name = STATIC_QUANT
@@ -1103,7 +1120,7 @@ def get_default_static_config() -> StaticQuantConfig:

 ######################## Smooth Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
     """Config class for smooth quantization."""

     name = SMOOTH_QUANT
@@ -1217,7 +1234,7 @@ def get_default_sq_config() -> SmoothQuantConfig:

 ######################## HQQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
     # Half-Quadratic Quantization (HQQ), more details:
     # Blog: https://mobiusml.github.io/hqq_blog/
     # Code: https://github.com/mobiusml/hqq
@@ -1298,7 +1315,7 @@ def get_default_hqq_config() -> HQQConfig:

 ######################## FP8 Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
     """Config class for FP8 quantization."""

     name = FP8_QUANT
@@ -1393,7 +1410,7 @@ def get_default_fp8_config_set() -> FP8Config:

 ######################## MixPrecision Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
     """Config class for mix-precision."""

     name = MIX_PRECISION
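Taken together, the hunks apply a single pattern: every algorithm config now derives from TorchBaseConfig instead of BaseConfig, so the string fallback lives in exactly one place. A tiny sketch of that inheritance shape (illustrative class bodies, not the real ones):

# Sketch only: the method bodies are stand-ins; just the shape matches the diff.
class BaseConfig:
    def _get_op_name_op_type_config(self):
        ...  # original framework-agnostic behavior

class TorchBaseConfig(BaseConfig):
    def _get_op_name_op_type_config(self):
        ...  # overridden once here with the string fallback

class RTNConfig(TorchBaseConfig): ...
class FP8Config(TorchBaseConfig): ...

# Every config resolves the method through the shared override:
assert RTNConfig._get_op_name_op_type_config is TorchBaseConfig._get_op_name_op_type_config
assert FP8Config._get_op_name_op_type_config is TorchBaseConfig._get_op_name_op_type_config

A consequence of this design is that any future PyTorch-specific tweak can go into TorchBaseConfig without touching the individual algorithm configs.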