 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from fractions import Fraction
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
@@ -16,6 +16,9 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 
@@ -28,7 +31,13 @@ class AutoRoundConfig(QuantizationConfig):
     SUPPORTED_DTYPES = {"int"}
     SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
     SUPPORTED_BACKENDS = {
-        "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex"
+        "auto",
+        "gptq",
+        "gptq:marlin",
+        "awq",
+        "awq:marlin",
+        "marlin",
+        "ipex",
     }
 
     def __init__(
@@ -109,26 +118,70 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
         )
 
     def get_layer_config(self, layer, layer_name: str):
-        # Priority: extra_config > block_name_to_quantize > type fallback
+
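+        # Resolve (bits, group_size, sym) for `name`: a per-layer entry in
+        # extra_config (e.g. {"bits": 8}) wins; otherwise fall back to the
+        # global settings, or to full precision (16, -1, True) when the
+        # layer is not quantized.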
+        def get_config(name: str, quantized: bool = True):
+            cfg = self.extra_config.get(name, {}) if self.extra_config else {}
+            return (
+                cfg.get("bits", self.weight_bits if quantized else 16),
+                cfg.get("group_size", self.group_size if quantized else -1),
+                cfg.get("sym", self.sym if quantized else True),
+            )
+
+        # 1. Exact match from config
         if self.extra_config and layer_name in self.extra_config:
-            cfg = self.extra_config[layer_name]
-            return cfg.get("bits", self.weight_bits), cfg.get(
-                "group_size", self.group_size), cfg.get("sym", self.sym)
+            return get_config(layer_name)
 
-        quantized = True
+        # 2. Determine whether layer should be quantized
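+        #    (the LM head stays unquantized unless the config says otherwise)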
+        quantized = not isinstance(layer, ParallelLMHead)
         if self.block_name_to_quantize:
             quantized = any(
                 layer_name.startswith(name)
                 for name in self.block_name_to_quantize)
-        elif isinstance(layer, ParallelLMHead):
-            quantized = False
 
-        return (self.weight_bits, self.group_size,
-                self.sym) if quantized else (16, -1, True)
+        # 3. Handle fused MoE
+        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(
+        ):
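+            # Expert weights are fused into one tensor, so every matching
+            # extra_config entry must agree on (bits, group_size, sym).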
+            moe_configs = [
+                get_config(name, quantized) for name in self.extra_config
+                if name.startswith(layer_name)
+            ]
+            if moe_configs:
+                if len(set(moe_configs)) == 1:
+                    return moe_configs[0]
+                raise ValueError(f"Fused MoE layer '{layer_name}' requires "
+                                 "consistent quant config for all sub-layers")
+
+        # 4. Handle fused QKV or other patterns
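+        #    (e.g. a fused "qkv_proj" is resolved via its "q_proj"/"k_proj"/
+        #    "v_proj" entries in packed_modules_mapping, which must all agree)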
+        if self.extra_config:
+            for fusion_key, sub_keys in self.packed_modules_mapping.items():
+                if fusion_key in layer_name and layer_name.count(
+                        fusion_key) == 1:
+                    sub_names = [
+                        layer_name.replace(fusion_key, sub_key)
+                        for sub_key in sub_keys
+                    ]
+                    sub_configs = [
+                        get_config(name, quantized) for name in sub_names
+                    ]
+                    if len(set(sub_configs)) == 1:
+                        return sub_configs[0]
+                    raise ValueError(
+                        f"Fused module '{layer_name}' requires "
+                        f"consistent quant config for {sub_names}")
+
+        # 5. Fallback
+        return get_config(layer_name, quantized)
 
     def check_quantized(self, weight_bits: int) -> bool:
         return weight_bits < 16
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
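+        # Remap HF checkpoint module names in the config to vLLM's internal
+        # layer names so the per-layer lookups above keep working.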
+        if self.block_name_to_quantize is not None:
+            self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
+                self.block_name_to_quantize)
+        if self.extra_config is not None:
+            self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
+
     def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
         from vllm.model_executor.layers.fused_moe import FusedMoE
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -141,9 +194,14 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
         else:
             return None
 
-        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
-                     prefix, layer.__class__.__name__, weight_bits, group_size,
-                     sym)
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
         if backend == "auto" or "marlin" in backend:
             AWQ_TYPE_MAP = {
                 4: scalar_types.uint4,
@@ -162,15 +220,19 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
         if use_marlin:
             from vllm.model_executor.layers.quantization.awq_marlin import (
                 AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
-            quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits,
-                                                group_size=group_size,
-                                                zero_point=not sym,
-                                                lm_head_quantized=False,
-                                                full_config={},
-                                                modules_to_not_convert=[])
+
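+            # AWQ carries zero points, so zero_point is simply the inverse
+            # of sym.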
+            quant_args_marlin = AWQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+                lm_head_quantized=False,
+                full_config={},
+                modules_to_not_convert=[],
+            )
         else:
             from vllm.model_executor.layers.quantization.awq import (
                 AWQConfig, AWQLinearMethod)
+
             quant_args = AWQConfig(
                 weight_bits=weight_bits,
                 group_size=group_size,
@@ -182,6 +244,7 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                 return AWQMoEMethod(quant_args_marlin)
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
+
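+            # No Marlin support: fall back to MoeWNA16 with an AWQ-style
+            # config dict.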
             config = {
                 "quant_method": "awq",
                 "bits": weight_bits,
@@ -206,26 +269,32 @@ def apply_gptq_quant_layer(self,
         from vllm.model_executor.layers.fused_moe import FusedMoE
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported, check_moe_marlin_supports_layer)
+
         weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
         if not self.check_quantized(weight_bits):
             if isinstance(layer, (LinearBase, ParallelLMHead)):
                 return UnquantizedLinearMethod()
             else:
                 return None
 
-        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
-                     prefix, layer.__class__.__name__, weight_bits, group_size,
-                     sym)
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
         if backend == "auto" or "marlin" in backend:
             GPTQ_TYPE_MAP = {
                 (4, True): scalar_types.uint4b8,
                 (8, True): scalar_types.uint8b128,
             }
-            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                          and check_marlin_supported(
+            use_marlin = (weight_bits,
+                          sym) in GPTQ_TYPE_MAP and check_marlin_supported(
                              GPTQ_TYPE_MAP[(weight_bits, sym)],
                              group_size,
-                             has_zp=not sym))
+                             has_zp=not sym)
         if isinstance(layer, FusedMoE):
             use_marlin = use_marlin and check_moe_marlin_supports_layer(
                 layer, group_size)
@@ -234,26 +303,33 @@ def apply_gptq_quant_layer(self,
         if use_marlin:
             from vllm.model_executor.layers.quantization.gptq_marlin import (
                 GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
-            quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits,
-                                                 group_size=group_size,
-                                                 is_sym=sym,
-                                                 lm_head_quantized=False,
-                                                 desc_act=False,
-                                                 dynamic={},
-                                                 full_config={})
+
+            quant_args_marlin = GPTQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                is_sym=sym,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+                full_config={},
+            )
         else:
             from vllm.model_executor.layers.quantization.gptq import (
                 GPTQConfig, GPTQLinearMethod)
-            quant_args = GPTQConfig(weight_bits=weight_bits,
-                                    group_size=group_size,
-                                    lm_head_quantized=False,
-                                    desc_act=False,
-                                    dynamic={})
+
+            quant_args = GPTQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+            )
 
         if isinstance(layer, FusedMoE):
             if use_marlin:
                 from vllm.model_executor.layers.quantization.moe_wna16 import (
                     MoeWNA16Config)
+
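+                # Re-express the GPTQ params as a MoeWNA16 config for the
+                # fused-MoE path.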
                 config = {
                     "quant_method": "gptq",
                     "bits": weight_bits,
@@ -282,6 +358,7 @@ def apply_ipex_quant_layer(self, layer, prefix: str):
             return None
         from vllm.model_executor.layers.quantization.ipex_quant import (
             IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
+
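+        # IPEX supplies the CPU/XPU kernels; the packing format selects the
+        # AWQ or GPTQ method below.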
         if isinstance(layer, (LinearBase, ParallelLMHead)):
             if "awq" in self.packing_format:
                 config = IPEXConfig(method="awq",