5 changes: 1 addition & 4 deletions .buildkite/test-pipeline.yaml
@@ -472,10 +472,7 @@ steps:
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_utils.py
- pytest -v -s models/test_vision.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
- pytest -v -s models/test_initialization.py

- label: Language Models Test (Standard)
mirror_hardwares: [amdexperimental]
32 changes: 24 additions & 8 deletions tests/models/registry.py
@@ -55,9 +55,18 @@ class _HfExamplesInfo:
trust_remote_code: bool = False
"""The ``trust_remote_code`` level required to load the model."""

v0_only: bool = False
"""The model is only available with the vLLM V0 engine."""

hf_overrides: dict[str, Any] = field(default_factory=dict)
"""The ``hf_overrides`` required to load the model."""

max_model_len: Optional[int] = None
"""
The maximum model length to use for this model. Some models default to a
length that is too large to fit into memory in CI.
"""

def check_transformers_version(
self,
*,
@@ -215,10 +224,11 @@ def check_available_online(
"OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
trust_remote_code=True),
"PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
"PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
Collaborator:
QQ - do we know why these are failing? Is it the head_dim?

Member Author:
Yes, it is head_dim for most of these. The only exceptions IIRC are Kimi and Pixtral, which have a tracking issue.
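For readers following the thread, the head size in question can be read straight off the HF config. The sketch below is only an illustration: the set of head sizes listed as supported by a V1 attention backend is an assumption made for the example, not vLLM's authoritative list.

```python
# Hypothetical illustration: derive head_dim for a checkpoint and compare it
# against an *assumed* set of head sizes for a V1 attention backend.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("microsoft/phi-2")
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
    head_dim = config.hidden_size // config.num_attention_heads

# Assumption for illustration only -- not the real backend support list.
ASSUMED_V1_HEAD_SIZES = {32, 64, 96, 128, 160, 192, 224, 256}
print(f"head_dim={head_dim}, "
      f"covered by assumed V1 head sizes: {head_dim in ASSUMED_V1_HEAD_SIZES}")
```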

"Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
"Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
trust_remote_code=True),
"Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
@@ -234,7 +244,8 @@ def check_available_online(
is_available_online=False),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b", # noqa: E501
is_available_online=False),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
v0_only=True),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -303,7 +314,8 @@ def check_available_online(
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}, # noqa: E501
v0_only=True),
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501
@@ -328,9 +340,11 @@ def check_available_online(
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}), # noqa: E501
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
min_transformers_version="4.51"),
min_transformers_version="4.51",
max_model_len=10240),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
"mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}), # noqa: E501
@@ -349,7 +363,8 @@ def check_available_online(
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
trust_remote_code=True),
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
trust_remote_code=True),
trust_remote_code=True,
v0_only=True),
"Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503", # noqa: E501
extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}), # noqa: E501
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -372,7 +387,8 @@ def check_available_online(
"Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
trust_remote_code=True),
"PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409", # noqa: E501
tokenizer_mode="mistral"),
tokenizer_mode="mistral",
v0_only=True),
"QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
extras={"chat": "Qwen/Qwen-VL-Chat"}, # noqa: E501
trust_remote_code=True,
17 changes: 13 additions & 4 deletions tests/models/test_initialization.py
@@ -15,12 +15,12 @@


@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
def test_can_initialize(model_arch):
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

# Avoid OOM
# Avoid OOM and reduce initialization time by only using 1 layer
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
hf_config.update(model_info.hf_overrides)

@@ -34,6 +34,12 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
"num_local_experts": 2,
})

if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
"num_layers": 1,
"num_hidden_layers": 1,
})

return hf_config

# Avoid calling model.forward()
@@ -46,7 +52,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
20 * GiB_bytes,
10 * GiB_bytes,
)

# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
@@ -55,7 +61,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1)):
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
LLM(
model_info.default,
tokenizer=model_info.tokenizer,
Expand All @@ -65,6 +73,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
"num_speculative_tokens": 1,
} if model_info.speculative_model else None,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides,
)
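The `hf_overrides` callable above is the load-bearing trick in this test; a minimal standalone sketch of the same pattern follows (the checkpoint name and `max_model_len` are placeholders for illustration, not values used by the test):

```python
# Minimal sketch of the "shrink the config to one layer" pattern used above.
from transformers import PretrainedConfig
from vllm import LLM

def tiny_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
    # Shrink the text backbone so initialization fits in CI memory.
    hf_config.update({"num_layers": 1, "num_hidden_layers": 1})
    # Shrink the vision tower too, when the model has one.
    if hasattr(hf_config, "vision_config"):
        hf_config.vision_config.update({
            "num_layers": 1,
            "num_hidden_layers": 1,
        })
    return hf_config

llm = LLM(
    "llava-hf/llava-1.5-7b-hf",
    load_format="dummy",          # dummy weights; skips loading real tensors
    hf_overrides=tiny_overrides,  # callable applied to the HF config
    max_model_len=2048,
)
```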
30 changes: 7 additions & 23 deletions vllm/model_executor/models/grok1.py
@@ -28,7 +28,7 @@
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ def __init__(
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
prefix=f"{prefix}.attn")
self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
1.0) if self.config else 1.0

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)

# Apply attention output multiplier if specified in config
attn_multiplier = getattr(self.config, "attn_output_multiplier",
None) if self.config else None
if attn_multiplier is not None:
output = output * attn_multiplier
output *= self.attn_multiplier
return output


@@ -261,8 +256,6 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
@@ -276,8 +269,6 @@ def forward(
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)

# Post attention normalization
@@ -341,8 +332,6 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
hidden_states, residual = layer(positions, hidden_states, residual)

if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -529,13 +516,10 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
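The multiplier change in this file simply hoists a per-call `getattr` into `__init__`; a small standalone sketch of that pattern (class and attribute names are illustrative, not vLLM internals):

```python
import torch
from torch import nn

class ScaledAttnOutput(nn.Module):
    """Illustration: read a config-derived scalar once in __init__ instead of
    re-fetching it with getattr on every forward call."""

    def __init__(self, config) -> None:
        super().__init__()
        # Fall back to 1.0 when the config is missing or lacks the field.
        self.attn_multiplier = getattr(config, "attn_output_multiplier",
                                       1.0) if config else 1.0

    def forward(self, attn_output: torch.Tensor) -> torch.Tensor:
        return attn_output * self.attn_multiplier
```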

15 changes: 9 additions & 6 deletions vllm/utils.py
@@ -2794,14 +2794,17 @@ def wrapper(*args, **kwargs):

# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
return (getattr(model_config.hf_text_config, "alibi", False) # Falcon
cfg = model_config.hf_text_config
return (getattr(cfg, "alibi", False) # Falcon
or ("BloomForCausalLM" in getattr(model_config.hf_config,
"architectures", [])) # Bloom
or getattr(model_config.hf_text_config, "position_encoding_type",
"") == "alibi" # codellm_1b_alibi
or
(hasattr(model_config.hf_text_config, "attn_config") # MPT
and model_config.hf_text_config.attn_config.get("alibi", False)))
or getattr(cfg, "position_encoding_type", "") ==
"alibi" # codellm_1b_alibi
or (hasattr(cfg, "attn_config") # MPT
and ((isinstance(cfg.attn_config, dict)
and cfg.attn_config.get("alibi", False)) or
(not isinstance(cfg.attn_config, dict)
and getattr(cfg.attn_config, "alibi", False)))))


def sha256(input) -> int:
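The new `isinstance` branch in `check_use_alibi` is just a dict-vs-object attribute lookup; the same idea expressed as a small helper (the helper name is made up for illustration):

```python
from typing import Any

def get_attn_config_value(attn_config: Any, key: str,
                          default: Any = False) -> Any:
    """Read `key` from an attn_config that may be a plain dict (MPT-style)
    or a config object exposing attributes."""
    if isinstance(attn_config, dict):
        return attn_config.get(key, default)
    return getattr(attn_config, key, default)

# Usage mirroring the check above:
#   use_alibi = (hasattr(cfg, "attn_config")
#                and get_attn_config_value(cfg.attn_config, "alibi", False))
```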