diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 181fbda57b3f..1f54b70f05dd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -472,10 +472,7 @@ steps:
     - pytest -v -s models/test_registry.py
     - pytest -v -s models/test_utils.py
     - pytest -v -s models/test_vision.py
-    # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
-    - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
+    - pytest -v -s models/test_initialization.py
 
 - label: Language Models Test (Standard)
   mirror_hardwares: [amdexperimental]
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 911a58e99d4c..22d532aa71e0 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -55,9 +55,18 @@ class _HfExamplesInfo:
     trust_remote_code: bool = False
     """The ``trust_remote_code`` level required to load the model."""
 
+    v0_only: bool = False
+    """The model is only available with the vLLM V0 engine."""
+
     hf_overrides: dict[str, Any] = field(default_factory=dict)
     """The ``hf_overrides`` required to load the model."""
 
+    max_model_len: Optional[int] = None
+    """
+    The maximum model length to use for this model. Some models default to a
+    length that is too large to fit into memory in CI.
+    """
+
     def check_transformers_version(
         self,
         *,
@@ -215,10 +224,11 @@ def check_available_online(
     "OrionForCausalLM": _HfExamplesInfo("OrionStarAI/Orion-14B-Chat",
                                         trust_remote_code=True),
     "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"),
-    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"),
+    "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2", v0_only=True),
     "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"),
     "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct",
-                                            trust_remote_code=True),
+                                            trust_remote_code=True,
+                                            v0_only=True),
     "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct",
                                          trust_remote_code=True),
     "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b",
@@ -234,7 +244,8 @@ def check_available_online(
                                                  is_available_online=False),
     "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b",  # noqa: E501
                                                 is_available_online=False),
-    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
+    "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t",
+                                           v0_only=True),
     "Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
     "SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct"),
     "TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@@ -303,7 +314,8 @@ def check_available_online(
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
     "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereForAI/aya-vision-8b"),  # noqa: E501
     "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b",  # noqa: E501
-                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}),  # noqa: E501
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"},  # noqa: E501
+                                                     v0_only=True),
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
@@ -328,9 +340,11 @@ def check_available_online(
                                                         {"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}),  # noqa: E501
     "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct",  # noqa: E501
                                                       extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},  # noqa: E501
-                                                      trust_remote_code=True),
+                                                      trust_remote_code=True,
+                                                      v0_only=True),
     "Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
-                                                      min_transformers_version="4.51"),
+                                                      min_transformers_version="4.51",
+                                                      max_model_len=10240),
     "LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
                                                      extras={"mistral": "mistral-community/pixtral-12b",  # noqa: E501
                                                              "mistral-fp8": "nm-testing/pixtral-12b-FP8-dynamic"}),  # noqa: E501
@@ -349,7 +363,8 @@ def check_available_online(
                                              extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
                                              trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
-                                                           trust_remote_code=True),
+                                                           trust_remote_code=True,
+                                                           v0_only=True),
     "Mistral3ForConditionalGeneration": _HfExamplesInfo("mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # noqa: E501
                                                         extras={"fp8": "nm-testing/Mistral-Small-3.1-24B-Instruct-2503-FP8-dynamic"}),  # noqa: E501
     "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
@@ -372,7 +387,8 @@ def check_available_online(
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501
-                                                       tokenizer_mode="mistral"),
+                                                       tokenizer_mode="mistral",
+                                                       v0_only=True),
     "QwenVLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen-VL",
                                                       extras={"chat": "Qwen/Qwen-VL-Chat"},  # noqa: E501
                                                       trust_remote_code=True,
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 446c4efbf6af..d403cb392fe0 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -15,12 +15,12 @@
 
 
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
-def test_can_initialize(model_arch):
+def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")
 
-    # Avoid OOM
+    # Avoid OOM and reduce initialization time by only using 1 layer
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
 
@@ -34,6 +34,12 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
             "num_local_experts": 2,
         })
 
+        if hasattr(hf_config, "vision_config"):
+            hf_config.vision_config.update({
+                "num_layers": 1,
+                "num_hidden_layers": 1,
+            })
+
         return hf_config
 
     # Avoid calling model.forward()
@@ -46,7 +52,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
         scheduler_kv_cache_config = get_kv_cache_config(
             vllm_config,
             kv_cache_specs[0],
-            20 * GiB_bytes,
+            10 * GiB_bytes,
         )
 
         # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
@@ -55,7 +61,9 @@ def _initialize_kv_caches_v1(self, vllm_config):
     with (patch.object(V0LLMEngine, "_initialize_kv_caches",
                        _initialize_kv_caches_v0),
          patch.object(V1EngineCore, "_initialize_kv_caches",
-                      _initialize_kv_caches_v1)):
+                      _initialize_kv_caches_v1), monkeypatch.context() as m):
+        if model_info.v0_only:
+            m.setenv("VLLM_USE_V1", "0")
         LLM(
             model_info.default,
             tokenizer=model_info.tokenizer,
@@ -65,6 +73,7 @@ def _initialize_kv_caches_v1(self, vllm_config):
                 "num_speculative_tokens": 1,
             } if model_info.speculative_model else None,
             trust_remote_code=model_info.trust_remote_code,
+            max_model_len=model_info.max_model_len,
             load_format="dummy",
             hf_overrides=hf_overrides,
         )
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 578d31a851a9..bc9e9a3c0206 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -28,7 +28,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -182,25 +182,20 @@ def __init__(
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
             prefix=f"{prefix}.attn")
+        self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
+                                       1.0) if self.config else 1.0
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-
-        # Apply attention output multiplier if specified in config
-        attn_multiplier = getattr(self.config, "attn_output_multiplier",
-                                  None) if self.config else None
-        if attn_multiplier is not None:
-            output = output * attn_multiplier
+        output *= self.attn_multiplier
         return output
 
 
@@ -261,8 +256,6 @@ def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -276,8 +269,6 @@ def forward(
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
 
         # Post attention normalization
@@ -341,8 +332,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -359,9 +348,7 @@ def forward(
 
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i - self.start_layer],
-                                            attn_metadata, residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -529,13 +516,10 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                    inputs_embeds)
         return hidden_states
 
diff --git a/vllm/utils.py b/vllm/utils.py
index fcc0ab3b237a..25694c121581 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2794,14 +2794,17 @@ def wrapper(*args, **kwargs):
 
 # Only relevant for models using ALiBi (e.g, MPT)
 def check_use_alibi(model_config: ModelConfig) -> bool:
-    return (getattr(model_config.hf_text_config, "alibi", False)  # Falcon
+    cfg = model_config.hf_text_config
+    return (getattr(cfg, "alibi", False)  # Falcon
             or ("BloomForCausalLM" in getattr(model_config.hf_config,
                                               "architectures", []))  # Bloom
-            or getattr(model_config.hf_text_config, "position_encoding_type",
-                       "") == "alibi"  # codellm_1b_alibi
-            or
-            (hasattr(model_config.hf_text_config, "attn_config")  # MPT
-             and model_config.hf_text_config.attn_config.get("alibi", False)))
+            or getattr(cfg, "position_encoding_type", "") ==
+            "alibi"  # codellm_1b_alibi
+            or (hasattr(cfg, "attn_config")  # MPT
+                and ((isinstance(cfg.attn_config, dict)
+                      and cfg.attn_config.get("alibi", False)) or
+                     (not isinstance(cfg.attn_config, dict)
+                      and getattr(cfg.attn_config, "alibi", False)))))
 
 
 def sha256(input) -> int:
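
For reference, the following is a minimal standalone sketch (not part of the patch) of the MPT `attn_config` handling that the reworked `check_use_alibi` in `vllm/utils.py` introduces: `attn_config` may be exposed either as a plain dict or as an object with attributes, so the `alibi` flag has to be read both ways. The `SimpleNamespace` configs below are hypothetical stand-ins used only for illustration.

# Illustrative sketch only -- mirrors the dict-vs-attribute branch added to
# check_use_alibi(); the SimpleNamespace configs are hypothetical examples.
from types import SimpleNamespace


def attn_config_uses_alibi(cfg) -> bool:
    """Return True if an MPT-style ``attn_config`` enables ALiBi."""
    if not hasattr(cfg, "attn_config"):
        return False
    attn_config = cfg.attn_config
    if isinstance(attn_config, dict):
        # Older configs store attn_config as a plain dict.
        return bool(attn_config.get("alibi", False))
    # Newer configs expose attn_config as an object with attributes.
    return bool(getattr(attn_config, "alibi", False))


assert attn_config_uses_alibi(SimpleNamespace(attn_config={"alibi": True}))
assert not attn_config_uses_alibi(
    SimpleNamespace(attn_config=SimpleNamespace(alibi=False)))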