@@ -1233,36 +1233,16 @@ def get_default_hqq_config() -> HQQConfig:
1233
1233
1234
1234
1235
1235
######################## FP8 Quant Config ###############################
1236
- # refer to habana_quantization_toolkit/_core/common.py
1237
- FP8_WHITE_LIST = [
1238
- "Matmul" ,
1239
- "Linear" ,
1240
- "FalconLinear" ,
1241
- "KVCache" ,
1242
- "Conv2d" ,
1243
- "LoRACompatibleLinear" ,
1244
- "LoRACompatibleConv" ,
1245
- "Softmax" ,
1246
- "ModuleFusedSDPA" ,
1247
- ]
1248
- if importlib .util .find_spec ("deepspeed" ):
1249
- FP8_WHITE_LIST .extend (["LinearLayer" , "LinearAllreduce" , "ScopedLinearAllReduce" , "LmHeadLinearAllreduce" ])
1250
1236
1237
+ from ..algorithms .fp8_quant ._core .common import mod_default_dict
1238
+ FP8_WHITE_LIST = mod_default_dict .keys ()
1251
1239
1252
1240
@register_config (framework_name = FRAMEWORK_NAME , algo_name = FP8_QUANT )
1253
1241
class FP8Config (BaseConfig ):
1254
1242
"""Config class for FP8 quantization."""
1255
1243
1256
1244
name = FP8_QUANT
1257
1245
1258
- # tunable params
1259
- params_list = [
1260
- "fp8_config" ,
1261
- "scale_method" ,
1262
- "observer" ,
1263
- "measure_exclude" ,
1264
- ]
1265
-
1266
1246
def __init__ (
1267
1247
self ,
1268
1248
dump_stats_path : str = "./hqt_output/measure" ,
@@ -1328,39 +1308,11 @@ def save_temp_json_file(self):
1328
1308
def get_config_set_for_tuning (cls ) -> Union [None , "FP8Config" , List ["FP8Config" ]]:
1329
1309
# just a simple example here
1330
1310
# usually write parameter combinations that are more suitable to tune based on experience.
1331
- return FP8Config (
1332
- fp8_config = ["E4M3" , "E5M2" ], scale_method = ["without_scale" , "maxabs_hw" ], measure_exclude = ["NONE" , "OUTPUT" ]
1333
- )
1311
+ return FP8Config ()
1334
1312
1335
1313
@classmethod
1336
- def register_supported_configs (cls ):
1337
- """Add all supported configs."""
1338
- supported_configs = []
1339
- linear_rtn_config = FP8Config (
1340
- mode = ["AUTO" , "MEASURE" , "QUANTIZE" ],
1341
- fp8_config = ["E4M3" , "E5M2" ],
1342
- scale_method = [
1343
- "without_scale" ,
1344
- "unit_scale" ,
1345
- "max" ,
1346
- "maxabs_hw" ,
1347
- "maxabs_pow2" ,
1348
- "maxabs_hw_opt_weight" ,
1349
- "maxabs_pow2_opt_weight" ,
1350
- "smoothquant_weights_output_channel_maxabs_pow2" ,
1351
- "weaksmoothquant_weights_output_channel_maxabs_pow2" ,
1352
- "act_maxabs_hw_weights_pcs_maxabs_pow2" ,
1353
- "act_maxabs_hw_weights_pcs_opt_pow2" ,
1354
- "act_maxabs_pow2_weights_pcs_maxabs_pow2" ,
1355
- "act_maxabs_pow2_weights_pcs_opt_pow2" ,
1356
- "smoothquant_opt" ,
1357
- ],
1358
- observer = ["shape" , "maxabs" , "maxabs_per_channel" , "save" ],
1359
- measure_exclude = ["NONE" , "OUTPUT" , "INPUT" , "ALL" ],
1360
- )
1361
- operators = list (FP8_WHITE_LIST )
1362
- supported_configs .append (OperatorConfig (config = linear_rtn_config , operators = operators ))
1363
- cls .supported_configs = supported_configs
1314
+ def register_supported_configs (cls ) -> List :
1315
+ pass
1364
1316
1365
1317
@staticmethod
1366
1318
def get_model_info (model : torch .nn .Module ) -> List [Tuple [str , Callable ]]:
0 commit comments