Merged

24 commits
5dc61b4
Refactor sliding window configuration to Transformers best practice
hmellor Jul 30, 2025
7c2fa94
Fix loading of mistral format configs
hmellor Jul 30, 2025
b083a60
Phi4flash is a custom model that misuses `sliding_window`
hmellor Jul 30, 2025
e6c7ebc
Merge branch 'main' into refactor-sliding-window-config
hmellor Jul 30, 2025
4657080
Merge branch 'main' into refactor-sliding-window-config
hmellor Jul 31, 2025
98f5365
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 1, 2025
10bb15c
Delay disabling of sliding window until max_model_len has been inferred
hmellor Aug 1, 2025
dabdc4a
Remove complicated getters which are no longer necessary
hmellor Aug 1, 2025
53f037d
typo
hmellor Aug 1, 2025
049f2bf
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 1, 2025
22821b7
Add `ModelConfig.is_interleaved` property and use it where we can
hmellor Aug 1, 2025
fbcd91c
Add comment
hmellor Aug 1, 2025
3d893e0
Move is_interleaved to a transformers util
hmellor Aug 1, 2025
f12e4f9
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 6, 2025
597dafc
Standardize Phi4Flash
hmellor Aug 6, 2025
d350535
Add workarounds for Phi4Flash
hmellor Aug 6, 2025
30bd425
Typo
hmellor Aug 6, 2025
9bb4299
Make workaround compatible before and after HF PRs
hmellor Aug 6, 2025
ac5f9bd
Less ambiguous variable naming
hmellor Aug 6, 2025
ad289a5
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 6, 2025
e1d837a
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 7, 2025
215626c
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 7, 2025
9cd23ea
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 8, 2025
068e134
Merge branch 'main' into refactor-sliding-window-config
hmellor Aug 9, 2025
2 changes: 1 addition & 1 deletion docs/contributing/model/basic.md
@@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m

To support a model with interleaving sliding windows, we need to take care of the following details:

- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
- Make sure the model's `config.json` contains `layer_types`.
Collaborator
@hmellor A dumb question: Many of the existing models don't have layer_types in their config.json. How do we handle them?

Member Author
The config classes in Transformers have been updated to use layer_types.

If these config classes load an older checkpoint that doesn't contain layer_types, the config class will create it using the deprecated fields available in config.json.

This way, BC is handled by Transformers, not us.
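
As a rough illustration of that backfilling (a hypothetical helper, not the actual Transformers code; attribute names such as `sliding_window_pattern` follow the deprecated fields mentioned above):

```python
# Hypothetical sketch of how a Transformers config class might backfill
# `layer_types` from deprecated sliding-window fields on older checkpoints.
def backfill_layer_types(config) -> None:
    if getattr(config, "layer_types", None) is not None:
        return  # newer checkpoints already ship layer_types
    pattern = getattr(config, "sliding_window_pattern", None)
    if pattern is None:
        return
    if isinstance(pattern, int):
        # e.g. pattern=2 -> every second layer uses full attention
        config.layer_types = [
            "sliding_attention" if (i + 1) % pattern else "full_attention"
            for i in range(config.num_hidden_layers)
        ]
    else:
        # e.g. "LLLLG" -> four local (sliding) layers, then one global layer
        config.layer_types = [
            "sliding_attention" if pattern[i % len(pattern)] == "L"
            else "full_attention"
            for i in range(config.num_hidden_layers)
        ]
```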

- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).

With these two steps, interleaving sliding windows should work with the model.
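
For reference, a minimal sketch of that per-layer wiring, following the pattern visible in the model diffs below (`extract_layer_index` and the `Attention` layer come from vLLM; the surrounding attributes are simplified):

```python
# Inside the attention module's __init__: derive this layer's window from
# `layer_types` and pass it to vLLM's Attention via per_layer_sliding_window.
layer_idx = extract_layer_index(prefix)
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None

self.attn = Attention(
    self.num_heads,
    self.head_dim,
    self.scaling,
    per_layer_sliding_window=sliding_window,  # None => full attention
    prefix=f"{prefix}.attn",
)
```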
22 changes: 0 additions & 22 deletions tests/test_config.py
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
assert model_config.max_model_len == expected


def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
# Test that the sliding window is correctly computed.
# For Qwen1.5/Qwen2, get_sliding_window() should be None
# when use_sliding_window is False.
qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")

qwen2_model_config.hf_config.use_sliding_window = False
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert qwen2_model_config.get_sliding_window() is None

qwen2_model_config.hf_config.use_sliding_window = True
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
mistral_model_config.hf_config.sliding_window = None
assert mistral_model_config.get_sliding_window() is None

mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
111 changes: 31 additions & 80 deletions vllm/config/__init__.py
@@ -40,8 +40,9 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
maybe_override_with_speculators_target_model, try_get_generation_config,
try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope)
is_interleaved, maybe_override_with_speculators_target_model,
try_get_generation_config, try_get_safetensors_metadata,
try_get_tokenizer_config, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
# yapf conflicts with isort for this block
@@ -721,53 +722,31 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
revision=self.revision,
)

# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config.
# TODO: remove this when Gemma 2 config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma2":
self.hf_text_config.sliding_window_pattern = 2

# TODO: remove this when Gemma 3n config updated in HuggingFace.
if self.hf_text_config.model_type == "gemma3n_text":
# 4 sliding window attention followed by 1 full attention
self.hf_text_config.sliding_window_pattern = "LLLLG"

sliding_window = getattr(self.hf_text_config, "sliding_window", None)
sliding_window_pattern = getattr(self.hf_text_config,
"sliding_window_pattern", None)
has_interleaved_attention = sliding_window_pattern is not None or (
isinstance(sliding_window, list))

if not self.disable_sliding_window and has_interleaved_attention:
if not envs.VLLM_USE_V1 and (backend := envs.VLLM_ATTENTION_BACKEND
) in ("XFORMERS", "FLASHINFER"):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

logger.warning_once(
"%s has interleaved attention, which is currently not supported by the %s backend. Disabling sliding window and capping the max length to the sliding window size (%d).", # noqa: E501
self.hf_text_config.model_type,
backend,
sliding_window_len_min,
)
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window

if hasattr(self.hf_text_config, "sliding_window"):
delattr(self.hf_text_config, "sliding_window")

sliding_window = None
# Interleaved attention is not supported by some backends in V0
if (not self.disable_sliding_window
and is_interleaved(self.hf_text_config)
and not envs.VLLM_USE_V1
and (backend := envs.VLLM_ATTENTION_BACKEND)
in ("XFORMERS", "FLASHINFER")):
logger.warning_once(
"%s has interleaved attention, which is currently not "
"supported by the %s backend. Disabling sliding window and "
"capping the max length to the sliding window size (%d).",
self.hf_text_config.model_type,
backend,
self.hf_text_config.sliding_window,
)
self.disable_sliding_window = True

self.original_max_model_len = self.max_model_len
self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
self.multimodal_config = self._init_multimodal_config()

if self.disable_sliding_window:
# Set after get_and_verify_max_len to ensure that max_model_len
# can be correctly capped to sliding window size
self.hf_text_config.sliding_window = None

if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()

@@ -1329,27 +1308,10 @@ def verify_with_parallel_config(
if self.use_async_output_proc:
self.use_async_output_proc = False

def get_hf_config_sliding_window(
self) -> Union[Optional[int], list[Optional[int]]]:
"""Get the sliding window size, or None if disabled."""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_text_config, "use_sliding_window")
and not self.hf_text_config.use_sliding_window):
return None
def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size from the HF text config if present."""
return getattr(self.hf_text_config, "sliding_window", None)

def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
"""Get the sliding window size, or None if disabled.
"""
# If user disables sliding window, return None.
if self.disable_sliding_window:
return None
# Otherwise get the value from the hf config.
return self.get_hf_config_sliding_window()

def get_vocab_size(self) -> int:
return getattr(self.hf_text_config, "vocab_size", 0)

@@ -1769,7 +1731,7 @@ def get_and_verify_max_len(self, max_model_len: int):
tokenizer_config=tokenizer_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
sliding_window=self.get_sliding_window(),
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
logger.info("Using max model len %s", max_model_len)
@@ -3664,7 +3626,7 @@ def _get_and_verify_max_len(
tokenizer_config: Optional[dict],
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
sliding_window: Optional[int],
spec_target_max_model_len: Optional[int] = None,
encoder_config: Optional[Any] = None,
) -> int:
@@ -3703,13 +3665,10 @@

# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
if disable_sliding_window and sliding_window_len is not None:

sliding_window_len_min = get_min_sliding_window(sliding_window_len)
max_len_key = "sliding_window" \
if sliding_window_len_min < derived_max_model_len else max_len_key
derived_max_model_len = min(derived_max_model_len,
sliding_window_len_min)
if (disable_sliding_window and sliding_window is not None
and sliding_window < derived_max_model_len):
max_len_key = "sliding_window"
derived_max_model_len = sliding_window

# Consider model_max_length in tokenizer_config
if tokenizer_config:
@@ -3810,14 +3769,6 @@ def _get_and_verify_max_len(
return int(max_model_len)


def get_min_sliding_window(
sliding_window: Union[int, list[Optional[int]]]) -> int:
if isinstance(sliding_window, list):
return min(s for s in sliding_window if s is not None)

return sliding_window


def get_served_model_name(model: str,
served_model_name: Optional[Union[str, list[str]]]):
"""
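For context, a rough sketch of what the new `is_interleaved` helper might check (hypothetical; the actual helper added to `vllm/transformers_utils/config.py` in this PR may differ):

```python
# Hypothetical sketch of an interleaved-attention check based on layer_types.
def is_interleaved(text_config) -> bool:
    """Return True if the model mixes sliding and full attention layers."""
    layer_types = getattr(text_config, "layer_types", None)
    if layer_types is None:
        return False
    return ("sliding_attention" in layer_types
            and "full_attention" in layer_types)
```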
10 changes: 9 additions & 1 deletion vllm/engine/arg_utils.py
@@ -39,6 +39,7 @@
from vllm.ray.lazy_utils import is_ray_initialized
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.config import is_interleaved
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
@@ -1090,14 +1091,21 @@ def create_engine_config(
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")

sliding_window: Optional[int] = None
if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding
# window. Otherwise CacheConfig.sliding_window will override the
# global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window()

cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
sliding_window=sliding_window,
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
22 changes: 7 additions & 15 deletions vllm/model_executor/models/commandr.py
@@ -182,21 +182,13 @@ def __init__(
)

# Model v2 has interleaved sliding windows, v1 does not
interleaved_sliding_window = getattr(config,
"interleaved_sliding_window",
None)
self.v1 = interleaved_sliding_window is None

layer_idx = extract_layer_index(prefix)
layer_has_sliding_window = (
getattr(config, "sliding_window_pattern", False) and
(layer_idx + 1) % self.config.sliding_window_pattern
!= 0) or (getattr(config, "layer_types", False)
and config.layer_types[layer_idx] == "sliding_attention")

self.sliding_window = (interleaved_sliding_window
or config.sliding_window
if layer_has_sliding_window else None)
self.v1 = isinstance(config, CohereConfig)

self.sliding_window = None
if not self.v1:
layer_idx = extract_layer_index(prefix)
if config.layer_types[layer_idx] == "sliding_attention":
self.sliding_window = config.sliding_window

self.attn = Attention(self.num_heads,
self.head_dim,
21 changes: 4 additions & 17 deletions vllm/model_executor/models/exaone4.py
@@ -159,25 +159,12 @@ def __init__(
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False

self.apply_all_layers = False # apply rotary embeddings to every layer.
layer_idx = extract_layer_index(prefix)
interleaved_sliding_window = getattr(config,
"interleaved_sliding_window",
4096)
sliding_window_pattern = getattr(config, "sliding_window_pattern",
"LLLG")

if sliding_window_pattern:
layer_has_sliding_window = (
layer_idx + 1) % sliding_window_pattern.__len__() != 0
else:
layer_has_sliding_window = False
self.apply_all_layers = True
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.sliding_window = config.sliding_window if is_sliding else None

if layer_has_sliding_window:
self.sliding_window = interleaved_sliding_window
else:
self.sliding_window = None
# apply rotary embeddings to every layer
self.apply_all_layers = not is_sliding

self.rotary_emb = get_rope(
self.head_dim,
9 changes: 3 additions & 6 deletions vllm/model_executor/models/gemma2.py
@@ -144,13 +144,10 @@ def __init__(self,
is_neox_style=True,
)

# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and getattr(
config, "interleaved_sliding_window", None) is not None)
sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if is_sliding else None

self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
14 changes: 4 additions & 10 deletions vllm/model_executor/models/gemma3.py
@@ -146,25 +146,19 @@ def __init__(self,
self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)

# TODO(woosuk): Add reference to the original HF implementation.
layer_idx = extract_layer_index(prefix)
self.is_sliding = (getattr(
config, "interleaved_sliding_window", None) is not None and (bool(
(layer_idx + 1) % config.sliding_window_pattern))) or (
getattr(config, "layer_types", None) is not None
and config.layer_types[layer_idx] == "sliding_attention")
self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
sliding_window = config.sliding_window if self.is_sliding else None

# Initialize the rotary embedding.
if self.is_sliding:
# Local attention. Override the values in config.json.
self.rope_theta = config.rope_local_base_freq
self.rope_scaling = {"rope_type": "default"}
self.sliding_window = (config.interleaved_sliding_window
or config.sliding_window)
else:
# Global attention. Use the values in config.json.
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
self.sliding_window = None
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
@@ -182,7 +176,7 @@ def __init__(self,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
per_layer_sliding_window=self.sliding_window,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn")

def forward(
6 changes: 2 additions & 4 deletions vllm/model_executor/models/gemma3_mm.py
@@ -502,8 +502,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.sliding_window = getattr(config.text_config,
"interleaved_sliding_window", None)

self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
@@ -690,11 +688,11 @@ def prepare_attn_masks(
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

if self.sliding_window is not None:
if (sliding_window := self.config.sliding_window) is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
diagonal=-sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
13 changes: 6 additions & 7 deletions vllm/model_executor/models/gemma3n.py
@@ -313,17 +313,16 @@ def __init__(self,
has_weight=False)

layer_idx = extract_layer_index(prefix)
is_sliding = config.layer_types[layer_idx] == "sliding_attention"
self.sliding_window = config.sliding_window if is_sliding else None

is_sliding_window = (
getattr(config, "interleaved_sliding_window", None) is not None
and config.layer_types[layer_idx] == "sliding_attention")

if is_sliding_window:
self.sliding_window = config.interleaved_sliding_window
# Initialize the rotary embedding.
if is_sliding:
# Local attention. Override the values in config.json.
rope_theta = config.rope_local_base_freq
rope_scaling = {"rope_type": "default"}
else:
self.sliding_window = None
# Global attention. Use the values in config.json.
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
