diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md
index edd9a47e132f..21b1f21d60a3 100644
--- a/docs/contributing/model/basic.md
+++ b/docs/contributing/model/basic.md
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m
 
 To support a model with interleaving sliding windows, we need to take care of the following details:
 
-- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
+- Make sure the model's `config.json` contains `layer_types`.
 - In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).
 
 With these two steps, interleave sliding windows should work with the model.
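The two doc steps above reduce to a per-layer lookup keyed on `layer_types`. A minimal, self-contained sketch of that mapping (illustrative only, not part of this diff; `per_layer_windows` is a hypothetical helper, the real reference is the `llama.py` line linked above):

```python
from typing import Optional


def per_layer_windows(layer_types: list[str],
                      sliding_window: int) -> list[Optional[int]]:
    """Value each layer would pass as ``per_layer_sliding_window``:
    the window for "sliding_attention" layers, None for "full_attention"."""
    return [
        sliding_window if layer_type == "sliding_attention" else None
        for layer_type in layer_types
    ]


# Gemma-2-style alternation over four layers with a 4096-token window.
print(per_layer_windows(["sliding_attention", "full_attention"] * 2, 4096))
# -> [4096, None, 4096, None]
```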
diff --git a/tests/test_config.py b/tests/test_config.py
index 441c07b99acf..19b1b74e4269 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
     assert model_config.max_model_len == expected
 
 
-def test_get_sliding_window():
-    TEST_SLIDING_WINDOW = 4096
-    # Test that the sliding window is correctly computed.
-    # For Qwen1.5/Qwen2, get_sliding_window() should be None
-    # when use_sliding_window is False.
-    qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
-
-    qwen2_model_config.hf_config.use_sliding_window = False
-    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert qwen2_model_config.get_sliding_window() is None
-
-    qwen2_model_config.hf_config.use_sliding_window = True
-    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-    mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
-    mistral_model_config.hf_config.sliding_window = None
-    assert mistral_model_config.get_sliding_window() is None
-
-    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
-    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
-
-
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config():
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 69c05b75d3eb..6b622e3ab7ea 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -40,8 +40,9 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    maybe_override_with_speculators_target_model, try_get_generation_config,
-    try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
+    is_interleaved, maybe_override_with_speculators_target_model,
+    try_get_generation_config, try_get_safetensors_metadata,
+    try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 # yapf conflicts with isort for this block
@@ -721,53 +722,31 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             revision=self.revision,
         )
 
-        # Workaround for Gemma 2 which uses interleaved sliding window
-        # attention, but it's not specified in its config.
-        # TODO: remove this when Gemma 2 config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma2":
-            self.hf_text_config.sliding_window_pattern = 2
-
-        # TODO: remove this when Gemma 3n config updated in HuggingFace.
-        if self.hf_text_config.model_type == "gemma3n_text":
-            # 4 sliding window attention followed by 1 full attention
-            self.hf_text_config.sliding_window_pattern = "LLLLG"
-
-        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        sliding_window_pattern = getattr(self.hf_text_config,
-                                         "sliding_window_pattern", None)
-        has_interleaved_attention = sliding_window_pattern is not None or (
-            isinstance(sliding_window, list))
-
-        if not self.disable_sliding_window and has_interleaved_attention:
-            if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
-                                         ) in ("XFORMERS", "FLASHINFER"):
-                sliding_window_len_min = get_min_sliding_window(
-                    self.hf_text_config.sliding_window)
-
-                logger.warning_once(
-                    "%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).",  # noqa: E501
-                    self.hf_text_config.model_type,
-                    backend,
-                    sliding_window_len_min,
-                )
-                self.disable_sliding_window = True
-            else:
-                # for a model with interleaved attention,
-                # the scheduler and the model treat it as full attention
-                # (i.e., not dropping any tokens outside the window).
-                # only the attention layer itself is aware of the sliding
-                # window, and use the window size to compute the attention.
-                self.hf_text_config.interleaved_sliding_window = sliding_window
-
-                if hasattr(self.hf_text_config, "sliding_window"):
-                    delattr(self.hf_text_config, "sliding_window")
-
-                sliding_window = None
+        # Interleaved attention is not supported by some backends in V0
+        if (not self.disable_sliding_window
+                and is_interleaved(self.hf_text_config)
+                and not envs.VLLM_USE_V1
+                and (backend := envs.VLLM_ATTENTION_BACKEND)
+                in ("XFORMERS", "FLASHINFER")):
+            logger.warning_once(
+                "%s has interleaved attention, which is currently not "
+                "supported by the %s backend. Disabling sliding window and "
+                "capping the max length to the sliding window size (%d).",
+                self.hf_text_config.model_type,
+                backend,
+                self.hf_text_config.sliding_window,
+            )
+            self.disable_sliding_window = True
 
         self.original_max_model_len = self.max_model_len
         self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         self.multimodal_config = self._init_multimodal_config()
 
+        if self.disable_sliding_window:
+            # Set after get_and_verify_max_len to ensure that max_model_len
+            # can be correctly capped to sliding window size
+            self.hf_text_config.sliding_window = None
+
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
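The `_get_and_verify_max_len` hunks further down drop the list-aware `get_min_sliding_window` helper in favour of a plain integer comparison. A toy version of the new capping rule (assumed function name, illustration only):

```python
from typing import Optional


def cap_max_len(derived_max_model_len: int,
                disable_sliding_window: bool,
                sliding_window: Optional[int]) -> int:
    # Mirrors the simplified branch in _get_and_verify_max_len below.
    if (disable_sliding_window and sliding_window is not None
            and sliding_window < derived_max_model_len):
        return sliding_window
    return derived_max_model_len


assert cap_max_len(8192, True, 4096) == 4096   # capped to the window
assert cap_max_len(8192, False, 4096) == 8192  # sliding window stays enabled
assert cap_max_len(2048, True, 4096) == 2048   # window larger than max len
```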
@@ -1329,27 +1308,10 @@ def verify_with_parallel_config(
         if self.use_async_output_proc:
             self.use_async_output_proc = False
 
-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
-        """Get the sliding window size, or None if disabled."""
-
-        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
-        # addition to sliding window size. We check if that field is present
-        # and if it's False, return None.
-        if (hasattr(self.hf_text_config, "use_sliding_window")
-                and not self.hf_text_config.use_sliding_window):
-            return None
+    def get_sliding_window(self) -> Optional[int]:
+        """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
 
-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
-        """Get the sliding window size, or None if disabled.
-        """
-        # If user disables sliding window, return None.
-        if self.disable_sliding_window:
-            return None
-        # Otherwise get the value from the hf config.
-        return self.get_hf_config_sliding_window()
-
     def get_vocab_size(self) -> int:
         return getattr(self.hf_text_config, "vocab_size", 0)
 
@@ -1769,7 +1731,7 @@ def get_and_verify_max_len(self, max_model_len: int):
             tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window(),
+            sliding_window=self.get_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
         logger.info("Using max model len %s", max_model_len)
@@ -3664,7 +3626,7 @@ def _get_and_verify_max_len(
     tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
    disable_sliding_window: bool,
-    sliding_window_len: Optional[Union[int, list[Optional[int]]]],
+    sliding_window: Optional[int],
     spec_target_max_model_len: Optional[int] = None,
     encoder_config: Optional[Any] = None,
 ) -> int:
@@ -3703,13 +3665,10 @@ def _get_and_verify_max_len(
 
     # If sliding window is manually disabled, max_length should be less
     # than the sliding window length in the model config.
-    if disable_sliding_window and sliding_window_len is not None:
-
-        sliding_window_len_min = get_min_sliding_window(sliding_window_len)
-        max_len_key = "sliding_window" \
-            if sliding_window_len_min < derived_max_model_len else max_len_key
-        derived_max_model_len = min(derived_max_model_len,
-                                    sliding_window_len_min)
+    if (disable_sliding_window and sliding_window is not None
+            and sliding_window < derived_max_model_len):
+        max_len_key = "sliding_window"
+        derived_max_model_len = sliding_window
 
     # Consider model_max_length in tokenizer_config
     if tokenizer_config:
@@ -3810,14 +3769,6 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
-def get_min_sliding_window(
-        sliding_window: Union[int, list[Optional[int]]]) -> int:
-    if isinstance(sliding_window, list):
-        return min(s for s in sliding_window if s is not None)
-
-    return sliding_window
-
-
 def get_served_model_name(model: str,
                           served_model_name: Optional[Union[str, list[str]]]):
     """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index c9dc99cad2d0..27d3df867238 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -39,6 +39,7 @@
 from vllm.ray.lazy_utils import is_ray_initialized
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
+from vllm.transformers_utils.config import is_interleaved
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
                         GiB_bytes, get_ip, is_in_ray_actor)
@@ -1090,6 +1091,13 @@ def create_engine_config(
                 "DualChunkFlashAttention is not supported on V1 engine. "
                 "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
 
+        sliding_window: Optional[int] = None
+        if not is_interleaved(model_config.hf_text_config):
+            # Only set CacheConfig.sliding_window if the model is all sliding
+            # window. Otherwise CacheConfig.sliding_window will override the
+            # global layers in interleaved sliding window models.
+            sliding_window = model_config.get_sliding_window()
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1097,7 +1105,7 @@ def create_engine_config(
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
-            sliding_window=model_config.get_sliding_window(),
+            sliding_window=sliding_window,
             enable_prefix_caching=self.enable_prefix_caching,
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
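The `create_engine_config` change above only propagates a window into `CacheConfig` when every attention layer slides; interleaved models leave it unset so their full-attention layers keep the whole context. A small sketch of that rule (`cache_config_window` is a hypothetical helper mirroring `is_interleaved`):

```python
from typing import Optional


def cache_config_window(layer_types: Optional[list[str]],
                        sliding_window: Optional[int]) -> Optional[int]:
    interleaved = (layer_types is not None
                   and {"full_attention", "sliding_attention"}
                   <= set(layer_types))
    return None if interleaved else sliding_window


# Uniformly sliding model (e.g. Mistral-7B-v0.1 style): window is global.
print(cache_config_window(["sliding_attention"] * 4, 4096))  # 4096
# Interleaved model (Gemma-2 style): CacheConfig keeps None.
print(cache_config_window(["sliding_attention", "full_attention"] * 2,
                          4096))  # None
```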
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 69281abf730a..4dd84b8f8fdd 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -182,21 +182,13 @@ def __init__(
         )
 
         # Model v2 has interleaved sliding windows, v1 does not
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             None)
-        self.v1 = interleaved_sliding_window is None
-
-        layer_idx = extract_layer_index(prefix)
-        layer_has_sliding_window = (
-            getattr(config, "sliding_window_pattern", False) and
-            (layer_idx + 1) % self.config.sliding_window_pattern
-            != 0) or (getattr(config, "layer_types", False)
-                      and config.layer_types[layer_idx] == "sliding_attention")
-
-        self.sliding_window = (interleaved_sliding_window
-                               or config.sliding_window
-                               if layer_has_sliding_window else None)
+        self.v1 = isinstance(config, CohereConfig)
+
+        self.sliding_window = None
+        if not self.v1:
+            layer_idx = extract_layer_index(prefix)
+            if config.layer_types[layer_idx] == "sliding_attention":
+                self.sliding_window = config.sliding_window
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py
index ecd942a76ace..827e9014184b 100644
--- a/vllm/model_executor/models/exaone4.py
+++ b/vllm/model_executor/models/exaone4.py
@@ -159,25 +159,12 @@ def __init__(
         if quant_config is not None and quant_config.get_name() == "gguf":
             is_neox_style = False
 
-        self.apply_all_layers = False  # apply rotary embeddings to every layer.
         layer_idx = extract_layer_index(prefix)
-        interleaved_sliding_window = getattr(config,
-                                             "interleaved_sliding_window",
-                                             4096)
-        sliding_window_pattern = getattr(config, "sliding_window_pattern",
-                                         "LLLG")
-
-        if sliding_window_pattern:
-            layer_has_sliding_window = (
-                layer_idx + 1) % sliding_window_pattern.__len__() != 0
-        else:
-            layer_has_sliding_window = False
-            self.apply_all_layers = True
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
 
-        if layer_has_sliding_window:
-            self.sliding_window = interleaved_sliding_window
-        else:
-            self.sliding_window = None
+        # apply rotary embeddings to every layer
+        self.apply_all_layers = not is_sliding
 
         self.rotary_emb = get_rope(
             self.head_dim,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 8beefb2cd0bd..8cfe92c64540 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -144,13 +144,10 @@ def __init__(self,
             is_neox_style=True,
         )
 
-        # reference:
-        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
         layer_idx = extract_layer_index(prefix)
-        use_sliding_window = (layer_idx % 2 == 0 and getattr(
-            config, "interleaved_sliding_window", None) is not None)
-        sliding_window = config.interleaved_sliding_window if \
-            use_sliding_window else None
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
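For Gemma 2 the deleted parity check and the new `layer_types` lookup pick the same layers, assuming the published alternating pattern in which even-indexed layers are local (illustrative check, not part of the diff):

```python
num_layers = 6
layer_types = [
    "sliding_attention" if i % 2 == 0 else "full_attention"
    for i in range(num_layers)
]
old_rule = [i % 2 == 0 for i in range(num_layers)]          # previous vLLM check
new_rule = [t == "sliding_attention" for t in layer_types]  # layer_types lookup
assert old_rule == new_rule
print(new_rule)  # [True, False, True, False, True, False]
```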
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 1a2ce65d1e4c..b762be3c5292 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -146,25 +146,19 @@ def __init__(self,
         self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
-        # TODO(woosuk): Add reference to the original HF implementation.
         layer_idx = extract_layer_index(prefix)
-        self.is_sliding = (getattr(
-            config, "interleaved_sliding_window", None) is not None and (bool(
-                (layer_idx + 1) % config.sliding_window_pattern))) or (
-                    getattr(config, "layer_types", None) is not None
-                    and config.layer_types[layer_idx] == "sliding_attention")
+        self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if self.is_sliding else None
+
+        # Initialize the rotary embedding.
         if self.is_sliding:
             # Local attention. Override the values in config.json.
             self.rope_theta = config.rope_local_base_freq
             self.rope_scaling = {"rope_type": "default"}
-            self.sliding_window = (config.interleaved_sliding_window
-                                   or config.sliding_window)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
             self.rope_scaling = config.rope_scaling
-            self.sliding_window = None
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ def __init__(self,
             cache_config=cache_config,
             quant_config=quant_config,
             logits_soft_cap=attn_logits_soft_cap,
-            per_layer_sliding_window=self.sliding_window,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn")
 
     def forward(
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index e9ee1ebdcc68..9871b11b3799 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -502,8 +502,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = getattr(config.text_config,
-                                      "interleaved_sliding_window", None)
 
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
@@ -690,11 +688,11 @@ def prepare_attn_masks(
                 global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
-            if self.sliding_window is not None:
+            if (sliding_window := self.config.sliding_window) is not None:
                 # Create a local causal mask with sliding window (1024).
                 local_attn_mask = torch.ones_like(global_attn_mask)
                 local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-self.sliding_window)
+                                             diagonal=-sliding_window)
                 local_attn_mask = torch.where(local_attn_mask == 0,
                                               global_attn_mask, float("-inf"))
                 local_attn_masks.append(local_attn_mask)
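`prepare_attn_masks` above now reads the window straight from the config, but the mask it builds is unchanged. A toy standalone version of that construction (assumed sizes, illustration only):

```python
import torch

seq_len, window = 5, 2
# Causal mask: 0 where attention is allowed, -inf strictly above the diagonal.
global_mask = torch.zeros(seq_len, seq_len)
upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
global_mask.masked_fill_(upper, float("-inf"))
# Entries more than `window` tokens in the past (col <= row - window) ...
too_old = torch.tril(torch.ones(seq_len, seq_len), diagonal=-window)
# ... get masked out on top of the causal mask, as prepare_attn_masks does.
local_mask = torch.where(too_old == 0, global_mask,
                         torch.tensor(float("-inf")))
print(local_mask)
```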
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index e16c03c8d3b5..7891b53dece4 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -313,17 +313,16 @@ def __init__(self,
                                      has_weight=False)
 
         layer_idx = extract_layer_index(prefix)
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        self.sliding_window = config.sliding_window if is_sliding else None
 
-        is_sliding_window = (
-            getattr(config, "interleaved_sliding_window", None) is not None
-            and config.layer_types[layer_idx] == "sliding_attention")
-
-        if is_sliding_window:
-            self.sliding_window = config.interleaved_sliding_window
+        # Initialize the rotary embedding.
+        if is_sliding:
+            # Local attention. Override the values in config.json.
             rope_theta = config.rope_local_base_freq
             rope_scaling = {"rope_type": "default"}
         else:
-            self.sliding_window = None
+            # Global attention. Use the values in config.json.
             rope_theta = config.rope_theta
             rope_scaling = config.rope_scaling
 
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index c99970284a95..9e7490e3c4f0 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -248,9 +248,7 @@ def __init__(
 
         vllm_config.cache_config.sliding_window = None
 
-        for attr in ("sliding_window", "interleaved_sliding_window"):
-            if hasattr(hf_config, attr):
-                delattr(hf_config, attr)
+        hf_config.sliding_window = None
 
         super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 48ec611df12d..bc511d833908 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -167,18 +167,11 @@ def __init__(
             rope_scaling=rope_scaling,
             quant_config=quant_config)
 
-        if hasattr(config, "interleaved_sliding_window"):
-            interleaved_sliding_window = config.interleaved_sliding_window
-            if isinstance(interleaved_sliding_window, int):
-                sliding_window = interleaved_sliding_window
-            elif isinstance(interleaved_sliding_window, list):
-                sw_idx = layer_idx % len(interleaved_sliding_window)
-                sliding_window = interleaved_sliding_window[sw_idx]
-            else:
-                raise ValueError(
-                    f"{type(interleaved_sliding_window)} is not supported.")
-        else:
-            sliding_window = None
+        sliding_window = None
+        if layer_types := getattr(config, "layer_types", None):
+            is_sliding = layer_types[layer_idx] == "sliding_attention"
+            if is_sliding:
+                sliding_window = config.sliding_window
 
         self.attn = Attention(
             self.num_heads,
diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py
index 1a761d01fc06..493a4192d35a 100644
--- a/vllm/model_executor/models/phi4flash.py
+++ b/vllm/model_executor/models/phi4flash.py
@@ -116,13 +116,8 @@ def __init__(self,
         self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
 
         # disable sliding window for the second half of the model
-        sliding_window = config.interleaved_sliding_window[layer_idx]
-        if layer_idx >= config.num_hidden_layers // 2:
-            assert sliding_window is None, \
-                "sliding_window must be none for the second decoder"
-        else:
-            assert sliding_window is not None, \
-                "sliding_window must be set for the first decoder"
+        is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+        sliding_window = config.sliding_window if is_sliding else None
 
         assert self.num_heads % 2 == 0, 'num_heads should be even'
         assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index e4f0de04e9a1..7304fbf120cc 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -49,6 +49,7 @@
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import is_interleaved
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -285,8 +286,7 @@ def __init__(self,
         quant_config = vllm_config.quant_config
 
         # TODO (@robertgshaw2): see if this can be moved out
-        if (cache_config.sliding_window is not None
-                and hasattr(config, "max_window_layers")):
+        if is_interleaved(vllm_config.model_config.hf_text_config):
             assert config.max_window_layers == config.num_hidden_layers, (
                 "Sliding window for some but all layers is not supported. "
                 "This model uses sliding window but `max_window_layers` = {} "
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 92e132045c27..fc4585618b04 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 from collections.abc import Iterable, Mapping
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from typing import Literal, Optional, Union
 
 import regex as re
@@ -382,33 +382,6 @@ def apply(
         )
 
 
-class ConfigOverride:
-    """Context manager to temporarily override config attributes."""
-
-    def __init__(self, config: PretrainedConfig, **kwargs):
-        self.config = config
-        self.kwargs = kwargs
-        self.kwargs_original = {}
-        self.kwargs_delete = set()
-
-    def __enter__(self):
-        """Override config attributes."""
-        for key, value in self.kwargs.items():
-            if not hasattr(self.config, key):
-                self.kwargs_delete.add(key)
-            self.kwargs_original[key] = getattr(self.config, key, None)
-            setattr(self.config, key, value)
-        return self.config
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        """Restore original config attributes."""
-        for key, value in self.kwargs_original.items():
-            if key in self.kwargs_delete:
-                delattr(self.config, key)
-            else:
-                setattr(self.config, key, value)
-
-
 class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
@@ -434,21 +407,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # To be updated in child classes for use in `load_weights`
         self.skip_prefixes: Optional[list[str]] = None
 
-        # vLLM handles interleaved sliding window attention by creating a new
-        # interleaved_sliding_window attribute and deleting the sliding_window
-        # attribute. This breaks the constructors in Transformers so we
-        # temporarily add the attribute back to construct the model.
-        config_override = nullcontext()
-        if hasattr(self.config, "interleaved_sliding_window"):
-            config_override = ConfigOverride(
-                self.config,
-                sliding_window=self.config.interleaved_sliding_window)
-
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
         # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
-        with init_on_device_without_buffers("meta"), config_override:
+        with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
                 self.config,
                 torch_dtype=self.model_config.dtype,
@@ -575,11 +538,10 @@ def create_attention_instances(self) -> dict[int, Attention]:
         attention_instances = {}
         for i in range(start, end):
             # Handle interleaved sliding window attention
-            sliding_window = None
-            if (hasattr(self.config, "interleaved_sliding_window")
-                    and hasattr(self.config, "sliding_window_pattern")
-                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
-                sliding_window = self.config.interleaved_sliding_window
+            per_layer_sliding_window = None
+            if (hasattr(self.config, "layer_types")
+                    and self.config.layer_types[i] == "sliding_attention"):
+                per_layer_sliding_window = self.config.sliding_window
 
             attention_instances[i] = Attention(
                 num_heads=num_heads,
@@ -590,7 +552,7 @@ def create_attention_instances(self) -> dict[int, Attention]:
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
-                per_layer_sliding_window=sliding_window,
+                per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn")
 
         return attention_instances
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index de779f94a4ab..6b70164c8caf 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)
 
 
+def is_interleaved(config: PretrainedConfig) -> bool:
+    """
+    Detect if the model with this config is used with interleaved attention.
+    """
+    text_config = config.get_text_config()
+    if layer_types := getattr(text_config, "layer_types", None):
+        interleaved_types = {"full_attention", "sliding_attention"}
+        return interleaved_types.issubset(layer_types)
+    return False
+
+
 def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     """Remap config attributes to match the expected names."""
     for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
@@ -423,6 +434,23 @@ def get_config(
                 raise e
 
         config = _maybe_remap_hf_config_attrs(config)
+
+        # Phi4Flash misuses this config as list[int]. Convert it to int and add
+        # the layer_types list[str] to make it HF compatible
+        if (config.model_type == "phi4flash"):
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
+            if not hasattr(config, "layer_types"):
+                config.layer_types = [
+                    "sliding_attention" if i < config.num_hidden_layers // 2
+                    and i % 2 == 1 else "full_attention"
+                    for i in range(config.num_hidden_layers)
+                ]
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
+            if isinstance(config.sliding_window, list):
+                config.sliding_window = next(
+                    filter(None, config.sliding_window), None)
+
     elif config_format == ConfigFormat.MISTRAL:
         # This function loads a params.json config which
         # should be used when loading models in mistral format
@@ -434,6 +462,18 @@ def get_config(
             config_dict["max_position_embeddings"] = max_position_embeddings
 
         config = adapt_config_dict(config_dict)
+
+        # Mistral configs may define sliding_window as list[int]. Convert it
+        # to int and add the layer_types list[str] to make it HF compatible
+        if ((sliding_window := getattr(config, "sliding_window", None))
+                and isinstance(sliding_window, list)):
+            pattern_repeats = config.num_hidden_layers // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config.layer_types = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            config.sliding_window = next(filter(None, sliding_window), None)
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
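Both `get_config` additions above normalize legacy window encodings into the `layer_types` plus single-int form that the rest of this diff relies on. A worked example of what they produce (hypothetical layer counts, illustration only):

```python
# Phi4Flash fallback: only odd-indexed layers in the first half slide.
num_hidden_layers = 8
phi4flash_layer_types = [
    "sliding_attention" if i < num_hidden_layers // 2 and i % 2 == 1
    else "full_attention"
    for i in range(num_hidden_layers)
]
print(phi4flash_layer_types)
# ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'full_attention', 'full_attention', 'full_attention']

# Mistral params.json style: list-valued sliding_window such as [4096, None].
sliding_window = [4096, None]
num_hidden_layers = 4
pattern_repeats = num_hidden_layers // len(sliding_window)
layer_types = [
    "full_attention" if w is None else "sliding_attention"
    for w in sliding_window * pattern_repeats
]
print(layer_types)                               # alternating sliding/full
print(next(filter(None, sliding_window), None))  # 4096

# is_interleaved() then just checks that both layer kinds are present.
print({"full_attention", "sliding_attention"}.issubset(layer_types))  # True
```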