
Commit ec152c8

hmellor and Isotr0py authored
Fix GPTQ model loading in Transformers backend (#25770)
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent 7977e50 commit ec152c8

File tree

3 files changed: 29 additions & 10 deletions


tests/models/test_transformers.py

Lines changed: 7 additions & 3 deletions
@@ -100,10 +100,9 @@ def test_distributed(
                              kwargs_test=kwargs)


-@pytest.mark.skipif(
-    current_platform.is_rocm(),
-    reason="bitsandbytes quantization is currently not supported in rocm.")
 @pytest.mark.parametrize("model, quantization_kwargs", [
+    ("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
+    ("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
     (
         "meta-llama/Llama-3.2-1B-Instruct",
         {
@@ -121,6 +120,11 @@ def test_quantization(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+    if (current_platform.is_rocm()
+            and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
+        pytest.skip(
+            "bitsandbytes quantization is currently not supported in rocm.")
+
     with vllm_runner(
             model, model_impl="auto", enforce_eager=True,
             **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
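
With AWQ and GPTQ checkpoints now in the parametrization, the old module-level skipif would have skipped every case on ROCm, so the guard moves into the test body and only fires for the bitsandbytes case. A minimal sketch of that pattern, assuming only that pytest is available (the platform flag and test body below are illustrative stand-ins, not the real fixtures):

import pytest

@pytest.mark.parametrize("quantization_kwargs", [
    {},                                # e.g. AWQ/GPTQ: the checkpoint config selects the method
    {"quantization": "bitsandbytes"},  # the only case that must be skipped on ROCm
])
def test_quantization_sketch(quantization_kwargs: dict) -> None:
    is_rocm = False  # stand-in for current_platform.is_rocm()
    if is_rocm and quantization_kwargs.get("quantization", "") == "bitsandbytes":
        pytest.skip("bitsandbytes quantization is currently not supported in rocm.")
    # ...the real test would load the model and compare logprobs here
    assert isinstance(quantization_kwargs, dict)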

vllm/model_executor/models/transformers.py

Lines changed: 17 additions & 5 deletions
@@ -447,7 +447,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.device_config: DeviceConfig = vllm_config.device_config
         self.model_config: ModelConfig = vllm_config.model_config
         self.parallel_config: ParallelConfig = vllm_config.parallel_config
-        self.quant_config: QuantizationConfig = vllm_config.quant_config
+        self.quant_config: Optional[
+            QuantizationConfig] = vllm_config.quant_config

         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
@@ -456,7 +457,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
+        """Skip loading weights whose qualname starts with these prefixes."""
         self.skip_substrs: list[str] = []
+        """Skip loading weights whose qualname contains these substrings."""
+        self.ignore_unexpected_prefixes: list[str] = []
+        """Ignore unexpected weights whose qualname starts with these prefixes.
+        """
+        self.ignore_unexpected_suffixes: list[str] = []
+        """Ignore unexpected weights whose qualname ends with these suffixes."""
+
+        # Skip loading extra bias for GPTQ models.
+        if self.quant_config and "gptq" in self.quant_config.get_name():
+            self.ignore_unexpected_suffixes.append(".bias")

         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -563,9 +575,7 @@ def tensor_parallel(self):
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel. {tip}")

-        def _tensor_parallel(module: nn.Module,
-                             prefix: str = "",
-                             tp_plan=None):
+        def _tensor_parallel(module: nn.Module, prefix: str, tp_plan=None):
             tp_plan = tp_plan or {}

             # If the current module is a PreTrainedModel, set the tp_plan for
@@ -597,7 +607,7 @@ def _tensor_parallel(module: nn.Module,
                              prefix=qual_name,
                              tp_plan=tp_plan)

-        _tensor_parallel(self.model)
+        _tensor_parallel(self.model, prefix="model")

     def create_attention_instances(
         self,
@@ -696,6 +706,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             self,
             skip_prefixes=self.skip_prefixes,
             skip_substrs=self.skip_substrs,
+            ignore_unexpected_prefixes=self.ignore_unexpected_prefixes,
+            ignore_unexpected_suffixes=self.ignore_unexpected_suffixes,
         )
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
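
The practical effect of the new ignore_unexpected_suffixes hook: per the added comment, GPTQ checkpoints can carry extra bias tensors the Transformers-backend model never defines, and instead of erroring on those unexpected weights the loader now drops any name ending in ".bias" whenever the active quantization method name contains "gptq". A self-contained sketch of that decision under stated assumptions (FakeQuantConfig is an illustrative stand-in for vLLM's QuantizationConfig):

from typing import Optional

class FakeQuantConfig:
    """Illustrative stand-in for vLLM's QuantizationConfig."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name

def suffixes_to_ignore(quant_config: Optional[FakeQuantConfig]) -> list[str]:
    # Mirrors the new __init__ logic: only GPTQ-family configs add ".bias".
    ignore_unexpected_suffixes: list[str] = []
    if quant_config and "gptq" in quant_config.get_name():
        ignore_unexpected_suffixes.append(".bias")
    return ignore_unexpected_suffixes

print(suffixes_to_ignore(FakeQuantConfig("gptq")))  # ['.bias']
print(suffixes_to_ignore(None))                     # []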

vllm/model_executor/models/utils.py

Lines changed: 5 additions & 2 deletions
@@ -109,13 +109,15 @@ def __init__(
         skip_prefixes: Optional[list[str]] = None,
         skip_substrs: Optional[list[str]] = None,
         ignore_unexpected_prefixes: Optional[list[str]] = None,
+        ignore_unexpected_suffixes: Optional[list[str]] = None,
     ) -> None:
         super().__init__()

         self.module = module
         self.skip_prefixes = skip_prefixes or []
         self.skip_substrs = skip_substrs or []
         self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+        self.ignore_unexpected_suffixes = ignore_unexpected_suffixes or []
         # update default skip_substrs
         self.skip_substrs += self.ROTARY_EMBEDS_UNUSED_WEIGHTS

@@ -149,8 +151,9 @@ def _can_skip(self, qualname: str) -> bool:
                 or any(substr in qualname for substr in self.skip_substrs))

     def _can_ignore_unexpected(self, qualname: str) -> bool:
-        return any(
-            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        iup = (qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+        ius = (qualname.endswith(s) for s in self.ignore_unexpected_suffixes)
+        return any(iup) or any(ius)

     def _load_param(
         self,
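
The updated _can_ignore_unexpected treats a weight as ignorable if its qualified name starts with any registered prefix or ends with any registered suffix. A minimal sketch of that check as a free function, with illustrative weight names (the real method lives on AutoWeightsLoader):

def can_ignore_unexpected(qualname: str,
                          ignore_unexpected_prefixes: list[str],
                          ignore_unexpected_suffixes: list[str]) -> bool:
    # Ignorable if the name starts with any registered prefix
    # or ends with any registered suffix.
    iup = (qualname.startswith(p) for p in ignore_unexpected_prefixes)
    ius = (qualname.endswith(s) for s in ignore_unexpected_suffixes)
    return any(iup) or any(ius)

# With ".bias" registered (as the GPTQ path above does), an unexpected bias
# tensor is skipped instead of failing the load; these names are illustrative.
assert can_ignore_unexpected("model.layers.0.mlp.gate_proj.bias", [], [".bias"])
assert not can_ignore_unexpected("model.layers.0.mlp.gate_proj.qweight", [], [".bias"])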
