
Commit fb9ce65

wenhuach21 authored and epwalsh committed

[Bugfix] fix mixed bits and visual language model quantization in AutoRound (vllm-project#21802)
Signed-off-by: Wenhua Cheng <[email protected]>
1 parent c8841c3 commit fb9ce65

File tree

1 file changed: +115 -38 lines changed

vllm/model_executor/layers/quantization/auto_round.py

Lines changed: 115 additions & 38 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from fractions import Fraction
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
@@ -16,6 +16,9 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 
@@ -28,7 +31,13 @@ class AutoRoundConfig(QuantizationConfig):
     SUPPORTED_DTYPES = {"int"}
     SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
     SUPPORTED_BACKENDS = {
-        "auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin", "ipex"
+        "auto",
+        "gptq",
+        "gptq:marlin",
+        "awq",
+        "awq:marlin",
+        "marlin",
+        "ipex",
     }
 
     def __init__(
@@ -109,26 +118,70 @@ def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
         )
 
     def get_layer_config(self, layer, layer_name: str):
-        # Priority: extra_config > block_name_to_quantize > type fallback
+
+        def get_config(name: str, quantized: bool = True):
+            cfg = self.extra_config.get(name, {}) if self.extra_config else {}
+            return (
+                cfg.get("bits", self.weight_bits if quantized else 16),
+                cfg.get("group_size", self.group_size if quantized else -1),
+                cfg.get("sym", self.sym if quantized else True),
+            )
+
+        # 1. Exact match from config
         if self.extra_config and layer_name in self.extra_config:
-            cfg = self.extra_config[layer_name]
-            return cfg.get("bits", self.weight_bits), cfg.get(
-                "group_size", self.group_size), cfg.get("sym", self.sym)
+            return get_config(layer_name)
 
-        quantized = True
+        # 2. Determine whether layer should be quantized
+        quantized = not isinstance(layer, ParallelLMHead)
         if self.block_name_to_quantize:
             quantized = any(
                 layer_name.startswith(name)
                 for name in self.block_name_to_quantize)
-        elif isinstance(layer, ParallelLMHead):
-            quantized = False
 
-        return (self.weight_bits, self.group_size,
-                self.sym) if quantized else (16, -1, True)
+        # 3. Handle fused MoE
+        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower(
+        ):
+            moe_configs = [
+                get_config(name, quantized) for name in self.extra_config
+                if name.startswith(layer_name)
+            ]
+            if moe_configs:
+                if len(set(moe_configs)) == 1:
+                    return moe_configs[0]
+                raise ValueError(f"Fused MoE layer '{layer_name}' requires "
+                                 f"consistent quant config for all sub-layers")
+
+        # 4. Handle fused QKV or other patterns
+        if self.extra_config:
+            for fusion_key, sub_keys in self.packed_modules_mapping.items():
+                if fusion_key in layer_name and layer_name.count(
+                        fusion_key) == 1:
+                    sub_names = [
+                        layer_name.replace(fusion_key, sub_key)
+                        for sub_key in sub_keys
+                    ]
+                    sub_configs = [
+                        get_config(name, quantized) for name in sub_names
+                    ]
+                    if len(set(sub_configs)) == 1:
+                        return sub_configs[0]
+                    raise ValueError(
+                        f"Fused module '{layer_name}' requires "
+                        f"consistent quant config for {sub_names}")
+
+        # 5. Fallback
+        return get_config(layer_name, quantized)
 
     def check_quantized(self, weight_bits: int) -> bool:
         return weight_bits < 16
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.block_name_to_quantize is not None:
+            self.block_name_to_quantize = hf_to_vllm_mapper.apply_list(
+                self.block_name_to_quantize)
+        if self.extra_config is not None:
+            self.extra_config = hf_to_vllm_mapper.apply_dict(self.extra_config)
+
     def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
         from vllm.model_executor.layers.fused_moe import FusedMoE
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
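The rewritten get_layer_config() is the heart of the mixed-bits fix: it resolves a (bits, group_size, sym) tuple per layer, preferring an exact extra_config entry, then requiring that fused modules (MoE experts, packed QKV) map to one consistent sub-layer config before falling back to the model-wide defaults. A minimal, self-contained sketch of that resolution order (the layer names and config values below are made up, and the fused-MoE branch is omitted):

```python
# Illustrative only: mirrors the resolution priority without importing vLLM.
DEFAULT = {"bits": 4, "group_size": 128, "sym": True}
extra_config = {
    "model.layers.0.self_attn.q_proj": {"bits": 8},
    "model.layers.0.self_attn.k_proj": {"bits": 8},
    "model.layers.0.self_attn.v_proj": {"bits": 8},
    "lm_head": {"bits": 16},
}
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


def get_config(name: str, quantized: bool = True):
    cfg = extra_config.get(name, {})
    return (
        cfg.get("bits", DEFAULT["bits"] if quantized else 16),
        cfg.get("group_size", DEFAULT["group_size"] if quantized else -1),
        cfg.get("sym", DEFAULT["sym"] if quantized else True),
    )


def resolve(layer_name: str, quantized: bool = True):
    # 1. An exact extra_config entry always wins.
    if layer_name in extra_config:
        return get_config(layer_name)
    # 2. Fused modules (e.g. qkv_proj) must resolve to one consistent config.
    for fusion_key, sub_keys in packed_modules_mapping.items():
        if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
            subs = {get_config(layer_name.replace(fusion_key, k), quantized)
                    for k in sub_keys}
            if len(subs) == 1:
                return subs.pop()
            raise ValueError(f"inconsistent quant config for '{layer_name}'")
    # 3. Otherwise fall back to the model-wide defaults.
    return get_config(layer_name, quantized)


print(resolve("model.layers.0.self_attn.qkv_proj"))  # (8, 128, True)
print(resolve("model.layers.0.mlp.gate_proj"))       # (4, 128, True)
print(resolve("lm_head"))  # (16, 128, True) -> bits == 16, left unquantized
```

A layer that resolves to 16 bits then fails check_quantized() and is served by UnquantizedLinearMethod, which is how per-layer "skip quantization" entries in a mixed-bits config are honored.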
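The new apply_vllm_mapper() hook is the visual-language-model side of the fix: multimodal models rename HF checkpoint modules to vLLM module names, so extra_config keys and block_name_to_quantize entries written against checkpoint names must be translated before get_layer_config() looks them up. A rough, self-contained stand-in for that renaming (this is not vLLM's WeightsMapper API, and the prefix below is hypothetical):

```python
class PrefixMapper:
    """Illustrative stand-in for a HF -> vLLM name mapper (not the real class)."""

    def __init__(self, orig_to_new_prefix: dict):
        self.orig_to_new_prefix = orig_to_new_prefix

    def _map(self, name: str) -> str:
        # Rewrite the first matching prefix, leave other names untouched.
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

    def apply_list(self, names: list) -> list:
        return [self._map(n) for n in names]

    def apply_dict(self, cfg: dict) -> dict:
        return {self._map(k): v for k, v in cfg.items()}


# Hypothetical multimodal checkpoint prefix mapped to a vLLM prefix.
mapper = PrefixMapper({"language_model.model.": "model."})
print(mapper.apply_dict(
    {"language_model.model.layers.0.mlp.down_proj": {"bits": 8}}))
# {'model.layers.0.mlp.down_proj': {'bits': 8}}
```

With the mapping applied up front, a mixed-bits extra_config produced against the HF checkpoint still matches the module names vLLM actually builds.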
@@ -141,9 +194,14 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
             else:
                 return None
 
-        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
-                     prefix, layer.__class__.__name__, weight_bits, group_size,
-                     sym)
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
         if backend == "auto" or "marlin" in backend:
             AWQ_TYPE_MAP = {
                 4: scalar_types.uint4,
@@ -162,15 +220,19 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
         if use_marlin:
             from vllm.model_executor.layers.quantization.awq_marlin import (
                 AWQMarlinConfig, AWQMarlinLinearMethod, AWQMoEMethod)
-            quant_args_marlin = AWQMarlinConfig(weight_bits=weight_bits,
-                                                group_size=group_size,
-                                                zero_point=not sym,
-                                                lm_head_quantized=False,
-                                                full_config={},
-                                                modules_to_not_convert=[])
+
+            quant_args_marlin = AWQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+                lm_head_quantized=False,
+                full_config={},
+                modules_to_not_convert=[],
+            )
         else:
             from vllm.model_executor.layers.quantization.awq import (
                 AWQConfig, AWQLinearMethod)
+
             quant_args = AWQConfig(
                 weight_bits=weight_bits,
                 group_size=group_size,
@@ -182,6 +244,7 @@ def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
                 return AWQMoEMethod(quant_args_marlin)
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
+
             config = {
                 "quant_method": "awq",
                 "bits": weight_bits,
@@ -206,26 +269,32 @@ def apply_gptq_quant_layer(self,
         from vllm.model_executor.layers.fused_moe import FusedMoE
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported, check_moe_marlin_supports_layer)
+
         weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
         if not self.check_quantized(weight_bits):
             if isinstance(layer, (LinearBase, ParallelLMHead)):
                 return UnquantizedLinearMethod()
             else:
                 return None
 
-        logger.debug("[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
-                     prefix, layer.__class__.__name__, weight_bits, group_size,
-                     sym)
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
         if backend == "auto" or "marlin" in backend:
             GPTQ_TYPE_MAP = {
                 (4, True): scalar_types.uint4b8,
                 (8, True): scalar_types.uint8b128,
             }
-            use_marlin = ((weight_bits, sym) in GPTQ_TYPE_MAP
-                          and check_marlin_supported(
+            use_marlin = (weight_bits,
+                          sym) in GPTQ_TYPE_MAP and check_marlin_supported(
                               GPTQ_TYPE_MAP[(weight_bits, sym)],
                               group_size,
-                              has_zp=not sym))
+                              has_zp=not sym)
             if isinstance(layer, FusedMoE):
                 use_marlin = use_marlin and check_moe_marlin_supports_layer(
                     layer, group_size)
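The reflowed use_marlin expression keeps the same gating: the Marlin path is only considered for symmetric 4-bit or 8-bit weights, and only if check_marlin_supported() (and, for MoE layers, check_moe_marlin_supports_layer()) accepts the layer's group size. A standalone sketch of that decision, with the kernel checks collapsed into a single boolean so it runs without vLLM:

```python
# Mirrors the gating above without importing vLLM; `kernel_ok` stands in for
# check_marlin_supported() / check_moe_marlin_supports_layer().
GPTQ_MARLIN_BITS = {(4, True), (8, True)}  # (weight_bits, sym)


def would_use_marlin(weight_bits: int, sym: bool, kernel_ok: bool) -> bool:
    return (weight_bits, sym) in GPTQ_MARLIN_BITS and kernel_ok


print(would_use_marlin(4, True, True))   # True  -> GPTQMarlin path
print(would_use_marlin(4, False, True))  # False -> asymmetric, plain GPTQ
print(would_use_marlin(3, True, True))   # False -> 3-bit, plain GPTQ
```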
@@ -234,26 +303,33 @@ def apply_gptq_quant_layer(self,
         if use_marlin:
             from vllm.model_executor.layers.quantization.gptq_marlin import (
                 GPTQMarlinConfig, GPTQMarlinLinearMethod, GPTQMarlinMoEMethod)
-            quant_args_marlin = GPTQMarlinConfig(weight_bits=weight_bits,
-                                                 group_size=group_size,
-                                                 is_sym=sym,
-                                                 lm_head_quantized=False,
-                                                 desc_act=False,
-                                                 dynamic={},
-                                                 full_config={})
+
+            quant_args_marlin = GPTQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                is_sym=sym,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+                full_config={},
+            )
         else:
             from vllm.model_executor.layers.quantization.gptq import (
                 GPTQConfig, GPTQLinearMethod)
-            quant_args = GPTQConfig(weight_bits=weight_bits,
-                                    group_size=group_size,
-                                    lm_head_quantized=False,
-                                    desc_act=False,
-                                    dynamic={})
+
+            quant_args = GPTQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+            )
 
         if isinstance(layer, FusedMoE):
             if use_marlin:
                 from vllm.model_executor.layers.quantization.moe_wna16 import (
                     MoeWNA16Config)
+
                 config = {
                     "quant_method": "gptq",
                     "bits": weight_bits,
@@ -282,6 +358,7 @@ def apply_ipex_quant_layer(self, layer, prefix: str):
             return None
         from vllm.model_executor.layers.quantization.ipex_quant import (
             IPEXAWQLinearMethod, IPEXConfig, IPEXGPTQLinearMethod)
+
         if isinstance(layer, (LinearBase, ParallelLMHead)):
             if "awq" in self.packing_format:
                 config = IPEXConfig(method="awq",
