
Commit be42d03

implement TorchBaseConfig (#1911)
Signed-off-by: xin3he <[email protected]>
Parent: 7a4715c

File tree: 2 files changed, +30 −13 lines


neural_compressor/torch/quantization/config.py

Lines changed: 29 additions & 12 deletions
@@ -72,9 +72,26 @@ class OperatorConfig(NamedTuple):
     valid_func_list: List[Callable] = []
 
 
+class TorchBaseConfig(BaseConfig):
+    # re-write func _get_op_name_op_type_config to fallback op_type with string
+    # because there are some special op_types for IPEX backend: `Linear&Relu`, `Linear&add`, ...
+    def _get_op_name_op_type_config(self):
+        op_type_config_dict = dict()
+        op_name_config_dict = dict()
+        for name, config in self.local_config.items():
+            if self._is_op_type(name):
+                # Convert the Callable to String.
+                new_name = self._op_type_to_str(name)
+                op_type_config_dict[new_name] = config
+            else:
+                op_name_config_dict[name] = config
+                op_type_config_dict[name] = config
+        return op_type_config_dict, op_name_config_dict
+
+
 ######################## RNT Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
-class RTNConfig(BaseConfig):
+class RTNConfig(TorchBaseConfig):
     """Config class for round-to-nearest weight-only quantization."""
 
     name = RTN
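
The override above exists because the IPEX backend reports fused modules under string op_types such as "Linear&Relu" or "Linear&add"; those names are not Callable, so they fail the _is_op_type check, and the else branch parks them in both dictionaries so a later op_type lookup still matches them. Below is a minimal, self-contained sketch of that splitting behavior; split_local_config, the callable() test, and __name__ are stand-ins for the BaseConfig helpers and are not the library API.

import torch

def split_local_config(local_config):
    """Illustration only: mimics TorchBaseConfig._get_op_name_op_type_config."""
    op_type_config_dict, op_name_config_dict = {}, {}
    for name, config in local_config.items():
        if callable(name):  # stand-in for self._is_op_type(name)
            # Convert the Callable op_type to its string name.
            op_type_config_dict[name.__name__] = config
        else:
            # Plain-string names (e.g. IPEX fused op_types) fall back into both dicts.
            op_name_config_dict[name] = config
            op_type_config_dict[name] = config
    return op_type_config_dict, op_name_config_dict

op_types, op_names = split_local_config({torch.nn.Linear: "fp32_cfg", "Linear&add": "fp32_cfg"})
print(op_types)  # {'Linear': 'fp32_cfg', 'Linear&add': 'fp32_cfg'}
print(op_names)  # {'Linear&add': 'fp32_cfg'}
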
@@ -242,7 +259,7 @@ def get_default_double_quant_config(type="BNB_NF4"):
 
 ######################## GPTQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=GPTQ, priority=PRIORITY_GPTQ)
-class GPTQConfig(BaseConfig):
+class GPTQConfig(TorchBaseConfig):
     """Config class for GPTQ.
 
     GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers.
@@ -397,7 +414,7 @@ def get_default_gptq_config(processor_type: Optional[Union[str, torch_utils.Proc
 
 ######################## AWQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AWQ, priority=PRIORITY_AWQ)
-class AWQConfig(BaseConfig):
+class AWQConfig(TorchBaseConfig):
     """Config class for AWQ.
 
     AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -539,7 +556,7 @@ def get_default_awq_config() -> AWQConfig:
 
 ######################## TEQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=TEQ, priority=PRIORITY_TEQ)
-class TEQConfig(BaseConfig):
+class TEQConfig(TorchBaseConfig):
     """Config class for TEQ.
 
     TEQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
@@ -677,7 +694,7 @@ def get_default_teq_config() -> TEQConfig:
 
 ######################## AUTOROUND Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=AUTOROUND, priority=PRIORITY_AUTOROUND)
-class AutoRoundConfig(BaseConfig):
+class AutoRoundConfig(TorchBaseConfig):
     """Config class for AUTOROUND.
 
     AUTOROUND: Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs.
@@ -827,7 +844,7 @@ def get_default_AutoRound_config(processor_type: Optional[Union[str, torch_utils
 
 ######################## MX Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MX_QUANT)
-class MXQuantConfig(BaseConfig):
+class MXQuantConfig(TorchBaseConfig):
     """Config class for MX quantization."""
 
     supported_configs: List[OperatorConfig] = []
@@ -940,7 +957,7 @@ def get_default_mx_config() -> MXQuantConfig:
 
 ######################## Dynamic Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=PT2E_DYNAMIC_QUANT)
-class DynamicQuantConfig(BaseConfig):
+class DynamicQuantConfig(TorchBaseConfig):
     """Config class for dynamic quantization."""
 
     name = PT2E_DYNAMIC_QUANT
@@ -1014,7 +1031,7 @@ def get_default_dynamic_config() -> DynamicQuantConfig:
 
 ######################## Static Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT)
-class StaticQuantConfig(BaseConfig):
+class StaticQuantConfig(TorchBaseConfig):
     """Config class for static quantization."""
 
     name = STATIC_QUANT
@@ -1103,7 +1120,7 @@ def get_default_static_config() -> StaticQuantConfig:
 
 ######################## Smooth Quant Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=SMOOTH_QUANT)
-class SmoothQuantConfig(BaseConfig):
+class SmoothQuantConfig(TorchBaseConfig):
     """Config class for smooth quantization."""
 
     name = SMOOTH_QUANT
@@ -1217,7 +1234,7 @@ def get_default_sq_config() -> SmoothQuantConfig:
 
 ######################## HQQ Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=HQQ, priority=PRIORITY_HQQ)
-class HQQConfig(BaseConfig):
+class HQQConfig(TorchBaseConfig):
     # Half-Quadratic Quantization (HQQ), more details:
     # Blog: https://mobiusml.github.io/hqq_blog/
     # Code: https://github.com/mobiusml/hqq
@@ -1298,7 +1315,7 @@ def get_default_hqq_config() -> HQQConfig:
 
 ######################## FP8 Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
-class FP8Config(BaseConfig):
+class FP8Config(TorchBaseConfig):
     """Config class for FP8 quantization."""
 
     name = FP8_QUANT
@@ -1393,7 +1410,7 @@ def get_default_fp8_config_set() -> FP8Config:
 
 ######################## MixPrecision Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=MIX_PRECISION)
-class MixPrecisionConfig(BaseConfig):
+class MixPrecisionConfig(TorchBaseConfig):
     """Config class for mix-precision."""
 
     name = MIX_PRECISION

test/3x/torch/quantization/test_static_quant.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def test_static_quant_fallback(self):
         quant_config = get_default_static_config()
         example_inputs = self.input
         # fallback by op_type
-        quant_config.set_local(torch.nn.Linear, StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
+        quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))
         prepared_model = prepare(fp32_model, quant_config=quant_config, example_inputs=example_inputs)
         run_fn(prepared_model)
         q_model = convert(prepared_model)
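
For context, a hedged sketch of how the fallback looks from the caller's side, mirroring the updated test above; the import path and the list form of set_local are assumptions based on the surrounding 3.x API and are not part of this diff, and _get_op_name_op_type_config is an internal method called here only for illustration.

import torch
from neural_compressor.torch.quantization import StaticQuantConfig, get_default_static_config  # assumed import path

quant_config = get_default_static_config()
# Mix a Callable op_type with the IPEX fused-op string "Linear&add".
quant_config.set_local([torch.nn.Linear, "Linear&add"], StaticQuantConfig(w_dtype="fp32", act_dtype="fp32"))

# With the TorchBaseConfig override, the string name also lands in the op_type dict,
# so the fp32 fallback reaches the fused op as well.
op_type_cfgs, op_name_cfgs = quant_config._get_op_name_op_type_config()
assert "Linear&add" in op_type_cfgs
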
