
Commit cb6882b

rasmith authored and mgoin committed
[Bugfix] Fix prefix strings for quantized VLMs (vllm-project#9772)
Signed-off-by: Randall Smith <[email protected]>
1 parent 48b2a63 commit cb6882b

Showing 20 changed files with 288 additions and 97 deletions.

vllm/model_executor/model_loader/loader.py

Lines changed: 8 additions & 3 deletions
@@ -147,15 +147,20 @@ def _get_model_initialization_kwargs(
     return extra_kwargs


-def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
+def build_model(model_class: Type[nn.Module],
+                hf_config: PretrainedConfig,
                 cache_config: Optional[CacheConfig],
-                quant_config: Optional[QuantizationConfig], *,
+                quant_config: Optional[QuantizationConfig],
+                *,
                 lora_config: Optional[LoRAConfig],
                 multimodal_config: Optional[MultiModalConfig],
-                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
+                scheduler_config: Optional[SchedulerConfig],
+                prefix: Optional[str] = None) -> nn.Module:
     extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                     multimodal_config,
                                                     scheduler_config)
+    if prefix:
+        extra_kwargs["prefix"] = prefix

     return model_class(config=hf_config,
                        cache_config=cache_config,
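
The new optional prefix argument is simply forwarded into the model constructor through extra_kwargs. A minimal, self-contained sketch of that forwarding behavior (ToyModel and build_model_sketch are illustrative stand-ins, not vLLM code):

from typing import Optional


class ToyModel:
    # Stand-in for a vLLM model class that accepts an optional prefix.
    def __init__(self, config=None, prefix: str = ""):
        # Submodule names get qualified with the prefix, mirroring how a
        # quantized checkpoint keys its layers.
        base = f"{prefix}." if prefix else ""
        self.qkv_name = f"{base}model.layers.0.self_attn.qkv_proj"


def build_model_sketch(model_class, config, prefix: Optional[str] = None):
    extra_kwargs = {}
    if prefix:  # same guard as the patched build_model above
        extra_kwargs["prefix"] = prefix
    return model_class(config=config, **extra_kwargs)


print(build_model_sketch(ToyModel, None).qkv_name)
# -> model.layers.0.self_attn.qkv_proj
print(build_model_sketch(ToyModel, None, prefix="language_model").qkv_name)
# -> language_model.model.layers.0.self_attn.qkv_proj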

vllm/model_executor/models/blip2.py

Lines changed: 4 additions & 1 deletion
@@ -507,7 +507,10 @@ def __init__(self,
         )

         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
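
Registering the wrapped LLM under the "language_model" prefix matters because quantization configs decide per layer whether to quantize it or skip it, keyed by the fully qualified module name stored in the checkpoint. A hedged illustration of that kind of lookup (the helper and the ignore list are toy examples, not the exact vLLM quantization API):

def is_layer_skipped_sketch(qualified_name: str, ignored_prefixes: list) -> bool:
    # Toy version of the prefix-based skip check a quantization config performs.
    return any(qualified_name.startswith(p) for p in ignored_prefixes)


# Quantized VLM checkpoints commonly leave the vision tower unquantized.
ignored = ["vision_model.", "qformer."]

print(is_layer_skipped_sketch(
    "language_model.model.layers.0.self_attn.qkv_proj", ignored))  # False: quantized
print(is_layer_skipped_sketch(
    "vision_model.encoder.layers.0.self_attn.qkv_proj", ignored))  # True: skipped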

vllm/model_executor/models/gemma.py

Lines changed: 39 additions & 19 deletions
@@ -43,7 +43,8 @@

 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)

 logger = init_logger(__name__)

@@ -83,16 +84,23 @@ def __init__(
         hidden_act: Optional[str] = None,
         hidden_activation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
         self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

     def forward(self, x):
@@ -104,15 +112,18 @@ def forward(self, x):

 class GemmaAttention(nn.Module):

-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int = 8192,
-                 rope_theta: float = 10000,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int = 8192,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -142,12 +153,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -186,6 +199,7 @@ def __init__(
         config: GemmaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -198,13 +212,15 @@ def __init__(
             rope_theta=config.rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             hidden_activation=getattr(config, "hidden_activation", None),
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                             eps=config.rms_norm_eps)
@@ -259,8 +275,8 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
-                                             ),
+            lambda prefix: GemmaDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -366,6 +382,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()

@@ -375,7 +392,10 @@ def __init__(
         self.lora_config = lora_config

         self.quant_config = quant_config
-        self.model = GemmaModel(config, cache_config, quant_config)
+        self.model = GemmaModel(config,
+                                cache_config,
+                                quant_config,
+                                prefix=maybe_prefix(prefix, "model"))
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
         self.make_empty_intermediate_tensors = (
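
maybe_prefix, newly imported from .utils, joins a parent prefix with a child module name so that nested models compose names like "model.layers.0.self_attn.qkv_proj", or "language_model.model.layers.0.self_attn.qkv_proj" when the Gemma stack sits inside a VLM. A one-line sketch of the helper, assuming the obvious join-or-passthrough behavior (the real implementation lives in vllm/model_executor/models/utils.py):

def maybe_prefix(prefix: str, name: str) -> str:
    # Prepend prefix with a dot, or return name unchanged when prefix is empty.
    return name if not prefix else f"{prefix}.{name}"


assert maybe_prefix("", "model") == "model"
assert maybe_prefix("language_model", "model") == "language_model.model"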

vllm/model_executor/models/internlm2.py

Lines changed: 39 additions & 17 deletions
@@ -30,7 +30,8 @@

 from .interfaces import SupportsPP
 from .utils import (is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers)
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class InternLM2MLP(nn.Module):
@@ -41,16 +42,23 @@ def __init__(
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.w2 = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
             bias=False,
-            quant_config=quant_config)
-        self.w2 = RowParallelLinear(intermediate_size,
-                                    hidden_size,
-                                    bias=False,
-                                    quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.w2",
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -75,6 +83,7 @@ def __init__(
         max_position_embeddings: int = 8192,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -108,12 +117,14 @@ def __init__(
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wqkv",
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.wo",
         )

         self.rotary_emb = get_rope(
@@ -123,12 +134,15 @@ def __init__(
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )

     def split_qkv(self, qkv: torch.Tensor):
         seq_len = qkv.shape[0]
@@ -176,6 +190,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -192,12 +207,14 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.attention",
         )
         self.feed_forward = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.attention_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -251,8 +268,8 @@ def __init__(
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: InternLMDecoderLayer(config, cache_config,
-                                                quant_config),
+            lambda prefix: InternLMDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
@@ -306,14 +323,19 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = InternLM2Model(config, cache_config, quant_config)
+        self.model = InternLM2Model(config,
+                                    cache_config,
+                                    quant_config,
+                                    prefix=maybe_prefix(prefix, "model"))
         self.output = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=maybe_prefix(prefix, "output"))
         if self.config.tie_word_embeddings:
             self.output.weight = self.model.tok_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
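
Within InternLM2Model, make_layers appends the layer index to the prefix it hands each decoder layer, so per-layer submodules end up with names like "model.layers.0.attention.wqkv". A toy sketch of the naming scheme only (the real make_layers in vllm/model_executor/models/utils.py also handles pipeline-parallel partitioning and returns start/end layer indices alongside the layers):

def make_layers_sketch(num_layers: int, layer_fn, prefix: str):
    # Build one "layer" per index, passing a prefix of the form "<prefix>.<i>".
    return [layer_fn(prefix=f"{prefix}.{i}") for i in range(num_layers)]


names = make_layers_sketch(
    2, lambda prefix: f"{prefix}.attention.wqkv", prefix="model.layers")
print(names)
# ['model.layers.0.attention.wqkv', 'model.layers.1.attention.wqkv']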

vllm/model_executor/models/internlm2_ve.py

Lines changed: 12 additions & 4 deletions
@@ -15,7 +15,7 @@
                                                    InternLM2MLP, InternLM2Model)
 from vllm.sequence import IntermediateTensors

-from .utils import make_layers
+from .utils import make_layers, maybe_prefix


 class InternLM2VEDecoderLayer(nn.Module):
@@ -25,6 +25,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -41,18 +42,21 @@ def __init__(
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix=f"{prefix}.attention",
         )
         self.feed_forward = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.feed_forward_ve = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward_ve",
         )
         self.attention_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -111,8 +115,8 @@ def __init__(
         super().__init__(config, cache_config, quant_config)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
-                                                   quant_config),
+            lambda prefix: InternLM2VEDecoderLayer(
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")

     def forward(
@@ -161,6 +165,10 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__(config, cache_config, quant_config)
-        self.model = InternLM2VEModel(config, cache_config, quant_config)
+        self.model = InternLM2VEModel(config,
+                                      cache_config,
+                                      quant_config,
+                                      prefix=maybe_prefix(prefix, "model"))

vllm/model_executor/models/internvl.py

Lines changed: 4 additions & 1 deletion
@@ -439,7 +439,10 @@ def __init__(self,
         )

         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")

         self.mlp1 = self._init_mlp1(config)

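Combined with the loader change, InternVL now yields fully qualified names for its language-model submodules. A worked example of the composition at the string level, assuming the helpers behave as sketched above:

prefix = "language_model"                  # passed via init_vllm_registered_model
model_prefix = f"{prefix}.model"           # maybe_prefix(prefix, "model")
layer_prefix = f"{model_prefix}.layers.0"  # per-layer prefix from make_layers
print(f"{layer_prefix}.attention.wqkv")
# -> language_model.model.layers.0.attention.wqkv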
