import fnmatch
-import re
from typing import Any, Dict, List, Optional, cast

import torch
@@ -122,6 +121,12 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
            for q_config in q_configs:
                q_config["output_tensors"] = None

+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(Dict[str, Any],
+                                   layer_quant_config.get("*q_proj"))
+            q_proj_q_config["output_tensors"] = None
+
        return cls(quant_config=config,
                   kv_cache_group=kv_cache_group,
                   kv_cache_config=kv_cache_config,
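For context, a minimal sketch of the layer_quant_config shape this hunk assumes; the placeholder values (e.g. the "fp8_e4m3" dtype) are illustrative and not taken from this PR, only the "*q_proj"/"*k_proj"/"*v_proj" keys and the "output_tensors" field appear in the diff:

# Illustrative Quark-style per-layer config (hypothetical values).
layer_quant_config = {
    "*k_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
    "*v_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
    "*q_proj": {"weight": {"dtype": "fp8_e4m3"}, "output_tensors": {"dtype": "fp8_e4m3"}},
}
# The existing loop clears output_tensors for the kv_cache_group entries
# ("*k_proj"/"*v_proj"); the added lines clear "*q_proj" as well, so all three
# projections end up with the same output_tensors value (None).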
@@ -148,6 +153,19 @@ def _check_scheme_supported(self,
        else:
            return False

+    def is_fp8_w8a8(self) -> bool:
+        # Returns True if all quantized layers in model are fp8 w8a8
+        global_quant_config = cast(
+            Dict[str, Any], self.quant_config.get("global_quant_config"))
+        layer_quant_configs = cast(Dict[str, Any],
+                                   self.quant_config.get("layer_quant_config"))
+        for config in (global_quant_config, *layer_quant_configs.values()):
+            weight_config = cast(Dict[str, Any], config.get("weight"))
+            input_config = cast(Dict[str, Any], config.get("input_tensors"))
+            if not self._is_fp8_w8a8(weight_config, input_config):
+                return False
+        return True
+
    def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                     input_quant: Optional[Dict[str, Any]]) -> bool:
        # Confirm weights and input quantized.
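As a rough illustration of what the new is_fp8_w8a8() helper iterates over, here is a hypothetical quant_config fragment; the exact schema and dtype strings are assumptions, not taken from this PR:

# The helper checks the global config plus every per-layer override and only
# returns True if each of them passes _is_fp8_w8a8 (fp8 weights and fp8
# input activations).
quant_config = {
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3"},
        "input_tensors": {"dtype": "fp8_e4m3"},
    },
    "layer_quant_config": {
        "*q_proj": {
            "weight": {"dtype": "fp8_e4m3"},
            "input_tensors": {"dtype": "fp8_e4m3"},
        },
    },
}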
@@ -286,25 +304,14 @@ def get_cache_scale(self, name: str) -> Optional[str]:
        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
-        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
-            return None
-
-        kv_proj_names = [
-            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
-        ]
-        if name.endswith(".output_scale"):
-            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
-                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
-                return name.replace(kv_output_scale_name, ".attn.k_scale")
-
-            elif len(kv_proj_names) == 2:
-                for kv_proj_name in kv_proj_names:
-                    if kv_proj_name in name and kv_proj_name == "k_proj":
-                        return name.replace(".k_proj.output_scale",
-                                            ".attn.k_scale")
-                    elif kv_proj_name in name and kv_proj_name == "v_proj":
-                        return name.replace(".v_proj.output_scale",
-                                            ".attn.v_scale")
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        if name.endswith(".output_scale") and ".q_proj" in name:
+            return name.replace(".q_proj.output_scale", ".attn.q_scale")
+        if name.endswith("self_attn.prob_output_scale"):
+            return name.replace(".prob_output_scale", ".attn.prob_scale")

        # If no matches, return None
        return None
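A quick sanity check of the simplified mapping above; the "model.layers.0" prefix is only illustrative:

# Hypothetical Quark checkpoint parameter names and the vLLM attention scale
# names the new get_cache_scale() logic maps them to.
expected = {
    "model.layers.0.self_attn.k_proj.output_scale":
        "model.layers.0.self_attn.attn.k_scale",
    "model.layers.0.self_attn.v_proj.output_scale":
        "model.layers.0.self_attn.attn.v_scale",
    "model.layers.0.self_attn.q_proj.output_scale":
        "model.layers.0.self_attn.attn.q_scale",
    "model.layers.0.self_attn.prob_output_scale":
        "model.layers.0.self_attn.attn.prob_scale",
}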