
Commit 385e63a

mgoin authored and Yuqi Zhang committed
[CI] Enable test_initialization to run on V1 (vllm-project#16736)
Signed-off-by: mgoin <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent 6998140 commit 385e63a

File tree: 5 files changed (+54, -45 lines)


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 4 deletions
@@ -472,10 +472,7 @@ steps:
   - pytest -v -s models/test_registry.py
   - pytest -v -s models/test_utils.py
   - pytest -v -s models/test_vision.py
-  # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
+  - pytest -v -s models/test_initialization.py
 
 - label: Language Models Test (Standard)
   mirror_hardwares: [amdexperimental]

tests/models/registry.py

Lines changed: 24 additions & 8 deletions
@@ -55,9 +55,18 @@ class _HfExamplesInfo:
     trust_remote_code: bool = False
     """The ``trust_remote_code`` level required to load the model."""
 
+    v0_only: bool = False
+    """The model is only available with the vLLM V0 engine."""
+
     hf_overrides: dict[str, Any] = field(default_factory=dict)
     """The ``hf_overrides`` required to load the model."""
 
+    max_model_len: Optional[int] = None
+    """
+    The maximum model length to use for this model. Some models default to a
+    length that is too large to fit into memory in CI.
+    """
+
     def check_transformers_version(
         self,
         *,
@@ -215,10 +224,11 @@ def check_available_online(
     "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                         trust_remote_code=True),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
-    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
-                                            trust_remote_code=True),
+                                            trust_remote_code=True,
+                                            v0_only=True),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
@@ -234,7 +244,8 @@ def check_available_online(
                                          is_available_online=False),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                 is_available_online=False),
-    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
+                                           v0_only=True),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -303,7 +314,8 @@ def check_available_online(
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),  # noqa: E501
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b",  # noqa: E501
-                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}),  # noqa: E501
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"},  # noqa: E501
+                                                     v0_only=True),
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
@@ -328,9 +340,11 @@ def check_available_online(
                                                   {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}),  # noqa: E501
     "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct",  # noqa: E501
                                                       extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
-                                                      trust_remote_code=True),
+                                                      trust_remote_code=True,
+                                                      v0_only=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                      min_transformers_version="4.51"),
+                                                      min_transformers_version="4.51",
+                                                      max_model_len=10240),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
@@ -349,7 +363,8 @@ def check_available_online(
                                         extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
                                         trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
                                                            trust_remote_code=True),
-                                                           trust_remote_code=True),
+                                                           trust_remote_code=True,
+                                                           v0_only=True),
     "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # noqa: E501
                                                         extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -372,7 +387,8 @@ def check_available_online(
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
-                                                       tokenizer_mode="mistral"),
+                                                       tokenizer_mode="mistral",
+                                                       v0_only=True),
     "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
                                                       extras={"chat": "Qwen/Qwen-VL-Chat"},  # noqa: E501
                                                       trust_remote_code=True,
tests/models/test_initialization.py

Lines changed: 13 additions & 4 deletions
@@ -15,12 +15,12 @@
 
 
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
-def test_can_initialize(model_arch):
+def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
-    # Avoid OOM
+    # Avoid OOM and reduce initialization time by only using 1 layer
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
 
@@ -34,6 +34,12 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "num_local_experts": 2,
             })
 
+        if hasattr(hf_config, "vision_config"):
+            hf_config.vision_config.update({
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+            })
+
         return hf_config
 
     # Avoid calling model.forward()
@@ -46,7 +52,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
         scheduler_kv_cache_config = get_kv_cache_config(
             vllm_config,
             kv_cache_specs[0],
-            20 * GiB_bytes,
+            10 * GiB_bytes,
         )
 
         # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
@@ -55,7 +61,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
     with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                        _initialize_kv_caches_v0),
           patch.object(V1EngineCore, "_initialize_kv_caches",
-                       _initialize_kv_caches_v1)):
+                       _initialize_kv_caches_v1), monkeypatch.context() as m):
+        if model_info.v0_only:
+            m.setenv("VLLM_USE_V1", "0")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,
@@ -65,6 +73,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
                 "num_speculative_tokens": 1,
             } if model_info.speculative_model else None,
             trust_remote_code=model_info.trust_remote_code,
+            max_model_len=model_info.max_model_len,
             load_format="dummy",
             hf_overrides=hf_overrides,
         )
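
With the V0 fallback now decided per test case rather than per pipeline step, the key pattern is the scoped monkeypatch.context() block. Below is a self-contained sketch of that pattern; ModelInfo and pick_engine are toy stand-ins for _HfExamplesInfo and the real LLM construction, not vLLM APIs:

import os
from dataclasses import dataclass

import pytest


@dataclass
class ModelInfo:
    name: str
    v0_only: bool = False


def pick_engine() -> str:
    # Stand-in for vLLM's engine selection: VLLM_USE_V1=0 forces the V0 engine.
    return "V0" if os.environ.get("VLLM_USE_V1") == "0" else "V1"


@pytest.mark.parametrize("info", [
    ModelInfo("microsoft/phi-2", v0_only=True),
    ModelInfo("bigcode/starcoder2-3b"),
])
def test_engine_selection(info: ModelInfo, monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        if info.v0_only:
            # Scoped to this block and undone on exit, so a V0-only model
            # does not leak the override into the next parametrized case.
            m.setenv("VLLM_USE_V1", "0")
        else:
            m.delenv("VLLM_USE_V1", raising=False)
        assert pick_engine() == ("V0" if info.v0_only else "V1")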

vllm/model_executor/models/grok1.py

Lines changed: 7 additions & 23 deletions
@@ -28,7 +28,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ def __init__(
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
             prefix=f"{prefix}.attn")
+        self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
+                                       1.0) if self.config else 1.0
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-
-        # Apply attention output multiplier if specified in config
-        attn_multiplier = getattr(self.config, "attn_output_multiplier",
-                                  None) if self.config else None
-        if attn_multiplier is not None:
-            output = output * attn_multiplier
+        output *= self.attn_multiplier
         return output
 
 
@@ -261,8 +256,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -276,8 +269,6 @@ def forward(
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
 
         # Post attention normalization
@@ -341,8 +332,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@ def forward(
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -529,13 +516,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
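
Two things happen in grok1.py: the per-layer kv_cache/attn_metadata arguments disappear from the forward signatures (the Attention layer now resolves them internally), and the attn_output_multiplier lookup moves from every forward call into __init__. A small sketch of the second change, using a toy module and config rather than the real Grok-1 classes:

import torch
from torch import nn


class ScaledAttnOutput(nn.Module):
    """Toy module: resolve an optional output multiplier once, not per step."""

    def __init__(self, hidden_size: int, config=None):
        super().__init__()
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # Hoisted out of forward(): defaulting to 1.0 also removes the old
        # "if attn_multiplier is not None" branch from the hot path.
        self.attn_multiplier = getattr(config, "attn_output_multiplier",
                                       1.0) if config else 1.0

    def forward(self, attn_output: torch.Tensor) -> torch.Tensor:
        output = self.o_proj(attn_output)
        output *= self.attn_multiplier  # no-op when the multiplier is 1.0
        return output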

vllm/utils.py

Lines changed: 9 additions & 6 deletions
@@ -2794,14 +2794,17 @@ def wrapper(*args, **kwargs):
 
 # Only relevant for models using ALiBi (e.g, MPT)
 def check_use_alibi(model_config: ModelConfig) -> bool:
-    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+    cfg = model_config.hf_text_config
+    return (getattr(cfg, "alibi", False)  # Falcon
             or ("BloomForCausalLM" in getattr(model_config.hf_config,
                                               "architectures", []))  # Bloom
-            or getattr(model_config.hf_text_config, "position_encoding_type",
-                       "") == "alibi"  # codellm_1b_alibi
-            or
-            (hasattr(model_config.hf_text_config, "attn_config")  # MPT
-             and model_config.hf_text_config.attn_config.get("alibi", False)))
+            or getattr(cfg, "position_encoding_type", "") ==
+            "alibi"  # codellm_1b_alibi
+            or (hasattr(cfg, "attn_config")  # MPT
+                and ((isinstance(cfg.attn_config, dict)
+                      and cfg.attn_config.get("alibi", False)) or
+                     (not isinstance(cfg.attn_config, dict)
+                      and getattr(cfg.attn_config, "alibi", False)))))
 
 
 def sha256(input) -> int:
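
Besides the cfg shorthand, the check_use_alibi rewrite makes the MPT branch tolerate attn_config being either a plain dict or a config object. A self-contained sketch of just that dispatch, with toy configs standing in for real Hugging Face config classes:

from types import SimpleNamespace
from typing import Any


def attn_config_uses_alibi(attn_config: Any) -> bool:
    # Mirrors the new branch: dict-style configs answer via .get(),
    # object-style configs via getattr(), both defaulting to False.
    if isinstance(attn_config, dict):
        return bool(attn_config.get("alibi", False))
    return bool(getattr(attn_config, "alibi", False))


# Both spellings of an MPT-style attn_config are now recognized.
assert attn_config_uses_alibi({"alibi": True})
assert attn_config_uses_alibi(SimpleNamespace(alibi=True))
assert not attn_config_uses_alibi({"rope_theta": 10000})
assert not attn_config_uses_alibi(SimpleNamespace())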
