
Commit b42b018

[SW-195483] Remove hard coded strings from FP8 config in INC
Change-Id: I1f58b74ab07eda93739b4e6c8be5041ac2beb714
1 parent c6af377

1 file changed (+5 −53 lines)

neural_compressor/torch/quantization/config.py
@@ -1233,36 +1233,16 @@ def get_default_hqq_config() -> HQQConfig:
 
 
 ######################## FP8 Quant Config ###############################
-# refer to habana_quantization_toolkit/_core/common.py
-FP8_WHITE_LIST = [
-    "Matmul",
-    "Linear",
-    "FalconLinear",
-    "KVCache",
-    "Conv2d",
-    "LoRACompatibleLinear",
-    "LoRACompatibleConv",
-    "Softmax",
-    "ModuleFusedSDPA",
-]
-if importlib.util.find_spec("deepspeed"):
-    FP8_WHITE_LIST.extend(["LinearLayer", "LinearAllreduce", "ScopedLinearAllReduce", "LmHeadLinearAllreduce"])
 
+from ..algorithms.fp8_quant._core.common import mod_default_dict
+FP8_WHITE_LIST = mod_default_dict.keys()
 
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=FP8_QUANT)
 class FP8Config(BaseConfig):
     """Config class for FP8 quantization."""
 
     name = FP8_QUANT
 
-    # tunable params
-    params_list = [
-        "fp8_config",
-        "scale_method",
-        "observer",
-        "measure_exclude",
-    ]
-
     def __init__(
         self,
         dump_stats_path: str = "./hqt_output/measure",
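The substantive change in this hunk: FP8_WHITE_LIST is no longer a hand-maintained copy of the module names from habana_quantization_toolkit/_core/common.py (plus a DeepSpeed-conditional extension), but is derived from the keys of mod_default_dict, the module registry in fp8_quant._core.common. Below is a minimal standalone sketch of the pattern; the registry contents here are illustrative stand-ins, not the real mod_default_dict.

# Sketch only: this "mod_default_dict" is a stand-in registry, not the real
# one from neural_compressor.torch.algorithms.fp8_quant._core.common.
mod_default_dict = {
    "Matmul": "matmul_default_config",
    "Linear": "linear_default_config",
    "KVCache": "kv_cache_default_config",
}

# dict.keys() returns a live view, so the white list tracks the registry:
# registering a new patchable module type needs no second edit here.
FP8_WHITE_LIST = mod_default_dict.keys()

mod_default_dict["Softmax"] = "softmax_default_config"
assert "Softmax" in FP8_WHITE_LIST  # picked up automatically

One caveat: a keys view is not a list, so callers that need list semantics must wrap it, as the removed register_supported_configs code already did with list(FP8_WHITE_LIST).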
@@ -1328,39 +1308,11 @@ def save_temp_json_file(self):
     def get_config_set_for_tuning(cls) -> Union[None, "FP8Config", List["FP8Config"]]:
         # just a simple example here
         # usually write parameter combinations that are more suitable to tune based on experience.
-        return FP8Config(
-            fp8_config=["E4M3", "E5M2"], scale_method=["without_scale", "maxabs_hw"], measure_exclude=["NONE", "OUTPUT"]
-        )
+        return FP8Config()
 
     @classmethod
-    def register_supported_configs(cls):
-        """Add all supported configs."""
-        supported_configs = []
-        linear_rtn_config = FP8Config(
-            mode=["AUTO", "MEASURE", "QUANTIZE"],
-            fp8_config=["E4M3", "E5M2"],
-            scale_method=[
-                "without_scale",
-                "unit_scale",
-                "max",
-                "maxabs_hw",
-                "maxabs_pow2",
-                "maxabs_hw_opt_weight",
-                "maxabs_pow2_opt_weight",
-                "smoothquant_weights_output_channel_maxabs_pow2",
-                "weaksmoothquant_weights_output_channel_maxabs_pow2",
-                "act_maxabs_hw_weights_pcs_maxabs_pow2",
-                "act_maxabs_hw_weights_pcs_opt_pow2",
-                "act_maxabs_pow2_weights_pcs_maxabs_pow2",
-                "act_maxabs_pow2_weights_pcs_opt_pow2",
-                "smoothquant_opt",
-            ],
-            observer=["shape", "maxabs", "maxabs_per_channel", "save"],
-            measure_exclude=["NONE", "OUTPUT", "INPUT", "ALL"],
-        )
-        operators = list(FP8_WHITE_LIST)
-        supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators))
-        cls.supported_configs = supported_configs
+    def register_supported_configs(cls) -> List:
+        pass
 
     @staticmethod
     def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]:
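For callers, the visible effect of this second hunk is that get_config_set_for_tuning() now yields a single default FP8Config rather than a pre-built sweep, and register_supported_configs() becomes a stub. A hedged usage sketch follows; it assumes FP8Config is importable from neural_compressor.torch.quantization (the package containing config.py) and that get_config_set_for_tuning remains a classmethod. The list-valued parameters are exactly the ones the removed code passed.

# Assumption: FP8Config is exported from neural_compressor.torch.quantization.
from neural_compressor.torch.quantization import FP8Config

# After this commit, the tuning entry point returns the plain default config.
default_set = FP8Config.get_config_set_for_tuning()  # equivalent to FP8Config()

# Anyone relying on the old sweep can rebuild it with the same parameters
# the removed code used:
old_sweep = FP8Config(
    fp8_config=["E4M3", "E5M2"],
    scale_method=["without_scale", "maxabs_hw"],
    measure_exclude=["NONE", "OUTPUT"],
)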
