From ff93550eb723d32a017b312fe0ef5017b98f4882 Mon Sep 17 00:00:00 2001
From: Juechen Liu
Date: Fri, 19 Sep 2025 22:22:31 -0700
Subject: [PATCH 01/17] expose request preemptions to ODS

Summary:
Currently, when no new blocks are available at a scheduling step, we already record this as [request events](https://fburl.com/code/rsiolx07) and send it back to the engine client via EngineCoreResponses, where it is later [aggregated](https://fburl.com/code/82r3x1lw) into the [iteration stats](https://fburl.com/code/lw96wgom).

In this diff, we expose this counter to ODS via MetaStatLoggerV1, so it is exported in the background.

We want this counter to measure the number of request preemptions when the KV cache is saturated.

Test Plan:
Ran locally and saturated cache usage to 100%; the "llm.vllm.request.preemptions" counter shows up.

{F1982066617}

Differential Revision: D82650207
---
 vllm/v1/metrics/loggers.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index f0076b2d81db..d1ce27ca1541 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -72,11 +72,13 @@ def _reset(self, now):
         # Tracked stats over current local logging interval.
         self.num_prompt_tokens: int = 0
         self.num_generation_tokens: int = 0
+        self.num_preempted_reqs: int = 0
 
     def _track_iteration_stats(self, iteration_stats: IterationStats):
         # Save tracked stats for token counters.
         self.num_prompt_tokens += iteration_stats.num_prompt_tokens
         self.num_generation_tokens += iteration_stats.num_generation_tokens
+        self.num_preempted_reqs += iteration_stats.num_preempted_reqs
 
     def _get_throughput(self, tracked_stats: int, now: float) -> float:
         # Compute summary metrics for tracked stats

From c3f7ed3f77eeff5793da730e7d8c9369fad51606 Mon Sep 17 00:00:00 2001
From: Roger Wang
Date: Sun, 21 Sep 2025 04:05:20 -0700
Subject: [PATCH 02/17] [MM][Perf] Minor Optimization on Qwen3-VL `fast_pos_embed_interpolate` (#25337)

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/qwen3_vl.py | 135 +++++++++++--------------
 1 file changed, 60 insertions(+), 75 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 17375ff0959d..ca232e03767b 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -270,6 +270,7 @@ def __init__(
         self.temporal_patch_size = vision_config.temporal_patch_size
         self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
         self.use_data_parallel = use_data_parallel
+        self.num_grid_per_side = int(self.num_position_embeddings**0.5)
 
         # NOTE: This is used for creating empty tensor for all_gather for
         # DP ViT.
Here out_hidden_size is enlarged due to deepstack @@ -377,82 +378,68 @@ def rot_pos_emb(self, grid_thw): rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def fast_pos_embed_interpolate(self, grid_thw): - num_grid_per_side = int(self.num_position_embeddings**0.5) + def fast_pos_embed_interpolate(self, + grid_thw: list[list[int]]) -> torch.Tensor: - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] + num_grid_per_side = self.num_grid_per_side + m_size = self.spatial_merge_size + hidden_dim = self.pos_embed.embedding_dim + outputs = [] for t, h, w in grid_thw: h_idxs = torch.linspace(0, num_grid_per_side - 1, h, - dtype=torch.float32) + dtype=torch.float32, + device=self.device) w_idxs = torch.linspace(0, num_grid_per_side - 1, w, - dtype=torch.float32) - - h_idxs_floor = h_idxs.to(torch.long) - w_idxs_floor = w_idxs.to(torch.long) - h_idxs_ceil = torch.clamp(h_idxs.to(torch.long) + 1, - max=num_grid_per_side - 1) - w_idxs_ceil = torch.clamp(w_idxs.to(torch.long) + 1, - max=num_grid_per_side - 1) - - dh = h_idxs - h_idxs_floor - dw = w_idxs - w_idxs_floor - - idx_list[0].extend(((h_idxs_floor * num_grid_per_side)[None].T + - w_idxs_floor[None]).flatten().tolist() * t) - idx_list[1].extend(((h_idxs_floor * num_grid_per_side)[None].T + - w_idxs_ceil[None]).flatten().tolist() * t) - idx_list[2].extend(((h_idxs_ceil * num_grid_per_side)[None].T + - w_idxs_floor[None]).flatten().tolist() * t) - idx_list[3].extend(((h_idxs_ceil * num_grid_per_side)[None].T + - w_idxs_ceil[None]).flatten().tolist() * t) - - weight_list[0].extend( - ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t) - weight_list[1].extend( - ((1 - dh)[None].T * dw[None]).flatten().tolist() * t) - weight_list[2].extend( - (dh[None].T * (1 - dw)[None]).flatten().tolist() * t) - weight_list[3].extend( - (dh[None].T * dw[None]).flatten().tolist() * t) - - device = self.pos_embed.weight.device - dtype = self.pos_embed.weight.dtype - - p0 = self.pos_embed( - torch.tensor( - idx_list[0], dtype=torch.long, device=device)) * torch.tensor( - weight_list[0], dtype=dtype, device=device)[:, None] - p1 = self.pos_embed( - torch.tensor( - idx_list[1], dtype=torch.long, device=device)) * torch.tensor( - weight_list[1], dtype=dtype, device=device)[:, None] - p2 = self.pos_embed( - torch.tensor( - idx_list[2], dtype=torch.long, device=device)) * torch.tensor( - weight_list[2], dtype=dtype, device=device)[:, None] - p3 = self.pos_embed( - torch.tensor( - idx_list[3], dtype=torch.long, device=device)) * torch.tensor( - weight_list[3], dtype=dtype, device=device)[:, None] - - patch_pos_embeds = p0 + p1 + p2 + p3 - patch_pos_embeds = patch_pos_embeds.split( - [t * h * w for t, h, w in grid_thw]) - patch_pos_embeds_permute = [] - m_size = self.spatial_merge_size - for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): - pos_embed = pos_embed.view(t, h // m_size, m_size, w // m_size, - m_size, -1).permute(0, 1, 3, 2, 4, - 5).flatten(0, 4) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = torch.cat(patch_pos_embeds_permute) - return patch_pos_embeds + dtype=torch.float32, + device=self.device) + + h_floor = h_idxs.to(torch.long) + w_floor = w_idxs.to(torch.long) + h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1) + w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1) + + dh = h_idxs - h_floor + dw = w_idxs - w_floor + + w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1) + w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1) + w10 = (dh[:, None] * 
(1 - dw)[None, :]).reshape(-1) + w11 = (dh[:, None] * dw[None, :]).reshape(-1) + + idx00 = (h_floor[:, None] * num_grid_per_side + + w_floor[None, :]).reshape(-1) + idx01 = (h_floor[:, None] * num_grid_per_side + + w_ceil[None, :]).reshape(-1) + idx10 = (h_ceil[:, None] * num_grid_per_side + + w_floor[None, :]).reshape(-1) + idx11 = (h_ceil[:, None] * num_grid_per_side + + w_ceil[None, :]).reshape(-1) + + indices = torch.stack([idx00, idx01, idx10, idx11], dim=0) + weights = torch.stack([w00, w01, w10, w11], + dim=0).to(dtype=self.dtype, + device=self.device) + weights = weights.unsqueeze(-1) + + embeds = self.pos_embed(indices) + weighted_embeds = embeds * weights + p0, p1, p2, p3 = weighted_embeds.unbind(dim=0) + combined = p0 + p1 + p2 + p3 + + combined = combined.view(h * w, hidden_dim) + repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous() + repeated = repeated.view(t, h // m_size, m_size, w // m_size, + m_size, hidden_dim) + repeated = repeated.permute(0, 1, 3, 2, 4, + 5).reshape(-1, hidden_dim) + outputs.append(repeated) + + return torch.cat(outputs, dim=0) def compute_attn_mask_seqlen( self, @@ -477,12 +464,9 @@ def forward( hidden_states = hidden_states + pos_embeds rotary_pos_emb = self.rot_pos_emb(grid_thw) - if isinstance(grid_thw, list): - grid_thw_tensor = torch.tensor(grid_thw, - device=hidden_states.device, - dtype=torch.int32) - else: - grid_thw_tensor = grid_thw + grid_thw_tensor = torch.tensor(grid_thw, + device=self.device, + dtype=torch.int32) cu_seqlens = torch.repeat_interleave( grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2], @@ -1224,7 +1208,8 @@ def _process_image_input( grid_thw_list, rope_type="rope_3d") else: - image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + image_embeds = self.visual(pixel_values, + grid_thw=grid_thw_list) # Split concatenated embeddings for each image item. # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync @@ -1526,4 +1511,4 @@ def get_mm_mapping(self) -> MultiModelKeys: language_model="language_model", connector="model.visual.merger", tower_model="model.visual.", - ) \ No newline at end of file + ) From a2c21a26492f302c267c7144be77f55425a33e4a Mon Sep 17 00:00:00 2001 From: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Date: Sun, 21 Sep 2025 13:36:47 +0200 Subject: [PATCH 03/17] [Bugfix] Typos in error message for missing model config file (#25339) Signed-off-by: simondanielsson --- vllm/transformers_utils/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cafc43f6b767..52e2c18a7784 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -524,10 +524,10 @@ def get_config( else: raise ValueError( "Could not detect config format for no config file found. " - "With config_format 'auto', ensure your model has either" - "config.json (HF format) or params.json (Mistral format)." - "Otherwise please specify your_custom_config_format" - "in engine args for customized config parser") + "With config_format 'auto', ensure your model has either " + "config.json (HF format) or params.json (Mistral format). 
" + "Otherwise please specify your_custom_config_format " + "in engine args for customized config parser.") except Exception as e: error_message = ( From 6d949dbd68ec3fbc1ab985db5fac968ab19080a1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 21 Sep 2025 19:41:02 +0800 Subject: [PATCH 04/17] [Optimization] Cache chat template result when processor fails to be loaded (#25341) Signed-off-by: DarkLight1337 --- vllm/entrypoints/chat_utils.py | 71 +++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c2c0ad74ef43..df49119d8642 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -421,6 +421,51 @@ def resolve_mistral_chat_template( return None +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. +""" + + +def _try_get_processor_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + model_config: ModelConfig, +) -> Optional[str]: + cache_key = (tokenizer.name_or_path, model_config.trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=model_config.trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None + return None + + def resolve_hf_chat_template( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], @@ -434,28 +479,10 @@ def resolve_hf_chat_template( # 2nd priority: AutoProcessor chat template, unless tool calling is enabled if tools is None: - try: - processor = cached_get_processor( - tokenizer.name_or_path, - processor_cls=( - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, - ), - trust_remote_code=model_config.trust_remote_code, - ) - if ( - isinstance(processor, ProcessorMixin) - and hasattr(processor, "chat_template") - and processor.chat_template is not None - ): - return processor.chat_template - except Exception: - logger.debug( - "Failed to load AutoProcessor chat template for %s", - tokenizer.name_or_path, - exc_info=True, - ) # noqa: E501 + chat_template = _try_get_processor_chat_template(tokenizer, + model_config) + if chat_template is not None: + return chat_template # 3rd priority: AutoTokenizer chat template try: From b2e5dc1d079d6ccd5368d664bd32a0415e765fbd Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Sep 2025 08:52:15 -0700 Subject: [PATCH 05/17] [V0 Deprecation] Remove V0 Sequence class & Sampler (#25332) Signed-off-by: Woosuk Kwon Signed-off-by: Woosuk Kwon --- tests/conftest.py | 2 +- .../generation/test_granite_speech.py | 2 +- .../multimodal/generation/test_phi4mm.py | 2 +- .../multimodal/generation/test_pixtral.py | 2 +- .../generation/vlm_utils/model_utils.py | 2 +- 
.../multimodal/generation/vlm_utils/types.py | 2 +- tests/models/utils.py | 2 +- tests/tokenization/test_detokenize.py | 140 +- tests/tool_use/test_jamba_tool_parser.py | 2 +- tests/tool_use/test_qwen3coder_tool_parser.py | 2 +- tests/tool_use/test_seed_oss_tool_parser.py | 2 +- tests/tool_use/test_xlam_tool_parser.py | 2 +- tests/v1/engine/test_output_processor.py | 2 +- vllm/executor/executor_base.py | 2 +- vllm/executor/ray_distributed_executor.py | 2 +- vllm/inputs/__init__.py | 13 +- vllm/inputs/registry.py | 67 +- vllm/model_executor/__init__.py | 4 +- .../model_executor/layers/logits_processor.py | 91 -- vllm/model_executor/layers/sampler.py | 1198 --------------- vllm/model_executor/models/medusa.py | 60 +- vllm/model_executor/models/mlp_speculator.py | 80 +- vllm/model_executor/models/phi4flash.py | 6 +- vllm/model_executor/sampling_metadata.py | 594 +------- vllm/sequence.py | 1322 +---------------- vllm/transformers_utils/detokenizer.py | 162 -- vllm/worker/worker_base.py | 2 +- 27 files changed, 70 insertions(+), 3697 deletions(-) delete mode 100644 vllm/model_executor/layers/sampler.py delete mode 100644 vllm/transformers_utils/detokenizer.py diff --git a/tests/conftest.py b/tests/conftest.py index f14b1e8780ad..dc70c9835959 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,10 +48,10 @@ initialize_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger +from vllm.logprobs import Logprob from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams -from vllm.sequence import Logprob from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils import set_default_torch_num_threads diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index f2e6fbfad6e8..c1305e0ae31c 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -7,8 +7,8 @@ import pytest from transformers import AutoModelForSpeechSeq2Seq +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest -from vllm.sequence import SampleLogprobs from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_phi4mm.py b/tests/models/multimodal/generation/test_phi4mm.py index 67d35213d642..77e2b90dd5e9 100644 --- a/tests/models/multimodal/generation/test_phi4mm.py +++ b/tests/models/multimodal/generation/test_phi4mm.py @@ -12,10 +12,10 @@ from transformers import AutoTokenizer from vllm.assets.image import ImageAsset +from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.image import convert_image_mode, rescale_image_size from vllm.platforms import current_platform -from vllm.sequence import SampleLogprobs from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput, PromptImageInput, VllmRunner) diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index cb3cc1d3d330..715b08ef90e5 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -13,8 +13,8 @@ from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins -from 
vllm.sequence import Logprob, SampleLogprobs from ....utils import VLLM_PATH, large_gpu_test from ...utils import check_logprobs_close diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 8b7d051218f1..ba55450ec8a9 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -19,7 +19,7 @@ GenerationConfig, GenerationMixin) from transformers.video_utils import VideoMetadata -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets diff --git a/tests/models/multimodal/generation/vlm_utils/types.py b/tests/models/multimodal/generation/vlm_utils/types.py index 945113196088..e39ca40fbbf5 100644 --- a/tests/models/multimodal/generation/vlm_utils/types.py +++ b/tests/models/multimodal/generation/vlm_utils/types.py @@ -12,7 +12,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import RunnerOption -from vllm.sequence import SampleLogprobs +from vllm.logprobs import SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from .....conftest import (AUDIO_ASSETS, IMAGE_ASSETS, HfRunner, ImageAsset, diff --git a/tests/models/utils.py b/tests/models/utils.py index 76c6e4823a12..5da2382cef81 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -12,7 +12,7 @@ from vllm.config import ModelConfig, ModelDType, RunnerOption from vllm.inputs import InputContext -from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs +from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs from .registry import HF_EXAMPLE_MODELS diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index bd2b91073d56..fe6c313d2966 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -8,10 +8,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.inputs import token_inputs -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup -from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, @@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast): assert decoded_text == '' assert out_ids == [len(tokenizer)] - - -@pytest.fixture -def detokenizer(tokenizer_name: str) -> Detokenizer: - tokenizer = get_tokenizer( - tokenizer_name, - tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", - trust_remote_code=False, - revision=None, - ) - - return Detokenizer(tokenizer) - - -@pytest.fixture(name="complete_sequence_token_ids") -def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer) -> list[int]: - return tokenizer(complete_sequence, add_special_tokens=False).input_ids - - -def create_sequence(prompt_token_ids=None): - prompt_token_ids = prompt_token_ids or [] - return Sequence( - seq_id=0, - inputs=token_inputs(prompt_token_ids), - block_size=16, - ) - - -def create_dummy_logprobs( - complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: - return [{ - token_id: Logprob(logprob=0.0), - 
token_id + 1: Logprob(logprob=0.1) - } for token_id in complete_sequence_token_ids] - - -def create_dummy_prompt_logprobs( - complete_sequence_token_ids: list[int] -) -> list[Optional[dict[int, Any]]]: - # logprob for the first prompt token is None. - logprobs: list[Optional[dict[int, Any]]] = [None] - logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) - return logprobs - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) -def test_decode_sequence_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer, - skip_special_tokens: bool): - """Verify Detokenizer decodes logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - logprobs=2) - - # Run sequentially. - seq = create_sequence() - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) - sequential_logprobs_text_chosen_token: list[str] = [] - sequential_logprobs_text_other_token: list[str] = [] - for new_token, logprobs in zip(complete_sequence_token_ids, - dummy_logprobs): - seq.append_token_id(new_token, logprobs) - detokenizer.decode_sequence_inplace(seq, sampling_params) - sequential_logprobs_text_chosen_token.append( - seq.output_logprobs[-1][new_token].decoded_token) - sequential_logprobs_text_other_token.append( - seq.output_logprobs[-1][new_token + 1].decoded_token) - sequential_result = seq.output_text - - assert sequential_result == "".join(sequential_logprobs_text_chosen_token) - assert sequential_result != "".join(sequential_logprobs_text_other_token) - - if not skip_special_tokens: - # Text for logprobs for the chosen token should be the same as the - # generated text. Note that this will only be true if we skip - # special tokens. - assert sequential_result == complete_sequence - - -@pytest.mark.parametrize("complete_sequence", TRUTH) -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -def test_decode_prompt_logprobs(complete_sequence: str, - complete_sequence_token_ids: list[int], - detokenizer: Detokenizer): - - # We want to use skip_special_tokens=False here but Mistral tokenizers - # don't support that. - if complete_sequence not in SPECIAL_TOKS_TRUTH: - skip_special_tokens = True - elif not isinstance(detokenizer.tokenizer, MistralTokenizer): - skip_special_tokens = False - else: - pytest.skip("MistralTokenizers don't support " - "skip_special_tokens=False") - return - """Verify Detokenizer decodes prompt logprobs correctly.""" - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, - prompt_logprobs=1) - - # Run sequentially. - seq = create_sequence(complete_sequence_token_ids) - seq_group = SequenceGroup(request_id="1", - seqs=[seq], - sampling_params=sampling_params, - arrival_time=0.0) - dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) - detokenizer.decode_prompt_logprobs_inplace(seq_group, - dummy_logprobs, - position_offset=0) - # First logprob is None. - decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ - 1:] # type: ignore - - # decoded_prompt_logprobs doesn't contain the first token. 
- token_ids = complete_sequence_token_ids - tokenizer = detokenizer.tokenizer - text_full = tokenizer.decode(token_ids, - skip_special_tokens=skip_special_tokens) - text_first = tokenizer.decode(token_ids[0], - skip_special_tokens=skip_special_tokens) - text = text_full[len(text_first):] - - # Text for logprobs for the chosen token should be the same as the - # prompt text. Note that the first logprob is None. - assert text == "".join([ - logprobs[token_id].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) - assert text != "".join([ - logprobs[token_id + 1].decoded_token - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) - ]) diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py index 35153139350b..57ace1fa22ac 100644 --- a/tests/tool_use/test_jamba_tool_parser.py +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import JambaToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "ai21labs/Jamba-tiny-dev" diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index ccb2acf512ca..f06fb2b9f2f0 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -13,7 +13,7 @@ ToolCall) from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( Qwen3CoderToolParser) -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" diff --git a/tests/tool_use/test_seed_oss_tool_parser.py b/tests/tool_use/test_seed_oss_tool_parser.py index c276a598aa68..118c7534622e 100644 --- a/tests/tool_use/test_seed_oss_tool_parser.py +++ b/tests/tool_use/test_seed_oss_tool_parser.py @@ -13,7 +13,7 @@ DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/tool_use/test_xlam_tool_parser.py b/tests/tool_use/test_xlam_tool_parser.py index 0bc22e4f1031..c07ca0f56d6b 100644 --- a/tests/tool_use/test_xlam_tool_parser.py +++ b/tests/tool_use/test_xlam_tool_parser.py @@ -11,7 +11,7 @@ DeltaMessage, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers import xLAMToolParser -from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer # Use a common model that is likely to be available diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index a9632ce54eac..bdb40be99aa3 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -12,9 +12,9 @@ STOP_STRINGS, 
DummyOutputProcessorTestVectors, MockEngineCore) +from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import PromptLogprobs, SampleLogprobs from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.output_processor import (OutputProcessor, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index b75b94ad0acc..fd4b992c3821 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -15,10 +15,10 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.tasks import SupportedTask from vllm.utils import make_async +from vllm.v1.outputs import SamplerOutput from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 78d0ee6c1e3f..84747575b496 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -17,12 +17,12 @@ from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster, ray) from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.v1.outputs import SamplerOutput if ray is not None: from ray.actor import ActorHandle diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index e9db2a0dc13a..46f49aaa013d 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -7,15 +7,7 @@ SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) -from .registry import (DummyData, InputContext, InputProcessingContext, - InputRegistry) - -INPUT_REGISTRY = InputRegistry() -""" -The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used -by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the -target model. 
-""" +from .registry import InputContext, InputProcessingContext __all__ = [ "DataPrompt", @@ -36,9 +28,6 @@ "build_explicit_enc_dec_prompt", "to_enc_dec_tuple_list", "zip_enc_dec_prompts", - "INPUT_REGISTRY", - "DummyData", "InputContext", "InputProcessingContext", - "InputRegistry", ] diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index f0b392e9767a..b5316b6d0574 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch from transformers import BatchFeature, PretrainedConfig, ProcessorMixin @@ -15,16 +15,9 @@ if TYPE_CHECKING: from vllm.config import ModelConfig - from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, - MultiModalRegistry) - from vllm.sequence import SequenceData from vllm.transformers_utils.tokenizer import AnyTokenizer else: ModelConfig = Any - MultiModalDataDict = Any - MultiModalPlaceholderDict = Any - MultiModalRegistry = Any - SequenceData = Any AnyTokenizer = Any _T = TypeVar("_T") @@ -191,61 +184,3 @@ def maybe_cast_dtype(x): f"on data={data} with kwargs={allowed_kwargs}") raise ValueError(msg) from exc - - -class DummyData(NamedTuple): - """ - Dummy data used for profiling. - - Note: This is only used in V0. - """ - - seq_data: SequenceData - multi_modal_data: Optional[MultiModalDataDict] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - - -class InputRegistry: - """ - Note: This is only used in V0. - """ - - def dummy_data_for_profiling( - self, - model_config: ModelConfig, - seq_len: int, - mm_registry: MultiModalRegistry, - is_encoder_data: bool = False, - ) -> DummyData: - """ - Create dummy data for profiling the memory usage of a model. - - The model is identified by ``model_config``. 
- """ - # Avoid circular import - from vllm.multimodal.cache import processor_only_cache_from_config - from vllm.sequence import SequenceData - - if not model_config.is_multimodal_model: - seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) - return DummyData(seq_data=seq_data) - - cache = processor_only_cache_from_config(model_config, mm_registry) - - # Encoder dummy data does not contain multi-modal data - if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data(model_config, - seq_len, - cache=cache) - seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) - return DummyData(seq_data=seq_data) - - dec_data = mm_registry.get_decoder_dummy_data(model_config, - seq_len, - cache=cache) - - return DummyData( - seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), - multi_modal_data=dec_data.multi_modal_data.get_data(), - multi_modal_placeholders=dec_data.multi_modal_placeholders, - ) diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 55dfe8088c8f..a59aebfac4ff 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -3,13 +3,11 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingMetadataCache) +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", - "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8a4ac214443e..8226437cb189 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A layer that compute logits from hidden_stats.""" -import inspect -from concurrent.futures import ThreadPoolExecutor from typing import Optional import torch -import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.custom_op import CustomOp @@ -16,11 +13,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None -if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: - _logits_processor_threadpool = ThreadPoolExecutor( - envs.VLLM_LOGITS_PROCESSOR_THREADS) - @CustomOp.register("logits_processor") class LogitsProcessor(CustomOp): @@ -60,15 +52,10 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, - prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None and prune_hidden_states: - hidden_states = _prune_hidden_states(hidden_states, - sampling_metadata) - # Get the logits for the next tokens. logits = self._get_logits(hidden_states, lm_head, embedding_bias) if logits is not None: @@ -79,12 +66,6 @@ def forward( if self.scale != 1.0: logits *= self.scale - - # Apply logits processors (if any). 
- if sampling_metadata is not None and \ - sampling_metadata.seq_groups is not None: - logits = _apply_logits_processors(logits, sampling_metadata) - return logits def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor: @@ -125,75 +106,3 @@ def extra_repr(self) -> str: s += f", org_vocab_size={self.org_vocab_size}" s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" return s - - -def _prune_hidden_states( - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - # NOTE(kzawora): The if guard is needed for Gaudi - in some scenarios - # (warmup, profile_run) we might not have selected_token_indices, - # so we skip pruning. - if sampling_metadata.selected_token_indices is not None: - return hidden_states.index_select( - 0, sampling_metadata.selected_token_indices) - else: - return hidden_states - - -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - found_logits_processors = False - logits_processed = 0 - logits_row_ids_and_logits_row_futures = [] - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - - for seq_id, logits_row_idx in zip(seq_ids, - seq_group.sample_indices): - logits_row = logits[logits_row_idx] - past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids - prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids - - if _logits_processor_threadpool is not None: - logits_row_ids_and_logits_row_futures.append( - (logits_row_idx, - _logits_processor_threadpool.submit( - _apply_logits_processors_single_seq, logits_row, - logits_processors, past_tokens_ids, - prompt_tokens_ids))) - else: - logits[logits_row_idx] = \ - _apply_logits_processors_single_seq( - logits_row, logits_processors, past_tokens_ids, - prompt_tokens_ids) - - logits_processed += len(seq_group.sample_indices) + len( - seq_group.prompt_logprob_indices) - - for logits_row_idx, future in logits_row_ids_and_logits_row_futures: - logits[logits_row_idx] = future.result() - - if found_logits_processors: - # verifies that no rows in logits were missed unexpectedly - assert logits_processed == logits.shape[0] - return logits - - -def _apply_logits_processors_single_seq(logits_row, logits_processors, - past_tokens_ids, - prompt_tokens_ids) -> torch.Tensor: - for logits_processor in logits_processors: - parameters = inspect.signature(logits_processor).parameters - if len(parameters) == 3: - logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids, - logits_row) - else: - logits_row = logits_processor(past_tokens_ids, logits_row) - return logits_row diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py deleted file mode 100644 index 9d93cad2420a..000000000000 --- a/vllm/model_executor/layers/sampler.py +++ /dev/null @@ -1,1198 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A layer that samples the next tokens from the model's outputs.""" -import itertools -from collections.abc import Iterator -from dataclasses import dataclass -from importlib.util import find_spec -from math import inf -from typing import Optional, Union - -import msgspec -import torch -import torch.nn as nn - -import vllm.envs as envs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from 
vllm.model_executor.layers.utils import apply_penalties -from vllm.model_executor.sampling_metadata import (SamplingMetadata, - SamplingTensors, - SequenceGroupToSample) -from vllm.sampling_params import SamplingType -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, SequenceOutput) - -if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - # yapf: disable - from flashinfer.sampling import ( - top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) - - # yapf: enable -else: - flashinfer_top_k_top_p_sampling = None - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -def get_sampler() -> torch.nn.Module: - if envs.VLLM_USE_V1: - # Lazy import: the v1 package isn't distributed - from vllm.v1.sample.sampler import Sampler as V1Sampler - return V1Sampler() - return Sampler() - - -# (num_token_ids, num_parent_ids) per sequence group. -SampleResultType = list[tuple[list[int], list[int]]] - -# Types of temporary data structures used for -# computing sample_result -SampleMetadataType = dict[SamplingType, tuple[list[int], - list[SequenceGroupToSample]]] -MultinomialSamplesType = dict[SamplingType, torch.Tensor] -SampleResultsDictType = dict[int, tuple[list[int], list[int]]] - - -# Encapsulates temporary data structures for computing -# sample_result. -# -# * For multi-step scheduling: must be returned -# by `Sampler.forward()` and used later to compute the pythonized -# sample_result -# -# * For single-step scheduling: consumed immediately -# inside `Sampler.forward()` to compute pythonized sample_result. -@dataclass -class SampleResultArgsType: - sample_metadata: SampleMetadataType - multinomial_samples: MultinomialSamplesType - sample_results_dict: SampleResultsDictType - sampling_metadata: SamplingMetadata - greedy_samples: Optional[torch.Tensor] - - -# Union of non-deferred (single-step scheduling) -# vs deferred (multi-step scheduling) -# sample result types -MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] - -# Abbreviation of the _sample() return type -SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] - - -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: list[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # Holds either (1) the pythonized sampler result (single-step scheduling) - # or (2) what will be arguments for later deferred pythonization of the - # sampler result (muliti-step scheduling) - deferred_sample_results_args: Optional[SampleResultArgsType] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. 
- sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # On-device tensor containing the sampled token embeddings (embeddings - # corresponding to the sampled token ids). Used when prompt embeddings are - # specified in lieu of prompt token ids or text. - sampled_token_embeds: Optional[torch.Tensor] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput: - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]: - return iter(self.outputs) - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return (f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr})") - - -class Sampler(nn.Module): - """Samples the next tokens from the model's outputs. - - This layer does the following: - 1. Discard the hidden states that are not used for sampling (i.e., all - tokens except the final one in each prompt). - 2. Compute the logits for the next tokens. - 3. Apply presence, frequency and repetition penalties. - 4. Apply temperature scaling. - 5. Apply top-p and top-k truncation. - 6. Sample the next tokens. - Here, each sequence group within the batch can have different sampling - parameters (e.g., sampling method, temperature, top-p, top-k, etc.). - - The structure of the logits tensor is coupled with the seq_groups in - sampling_metadata. Typically, each sequence in each seq_group has one row in - logits for the next token to be sampled; however, for a seq_group with a - prompt request with the prompt_logprobs sampling parameter, there are rows - in logits for each token in the input prompt. - """ - - def __init__(self): - super().__init__() - - # Whether or not the SamplerOutput should have on-device tensors - # containing the sampled token ids and probabilities. This is used by - # speculative decoding and when prompt embeddings are specified. - self.include_gpu_probs_tensor = False - self.should_modify_greedy_probs_inplace = False - - def _init_sampling_tensors( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ): - """The goal here is to reuse sampling tensors between similar decode - runs. This is possible because sampling logic does not change between - decodes of the same sequences. - """ - _, vocab_size = logits.shape - - # First free any existing stored sampling tensors. - # This is necessary because some sampling tensors may - # have pinned memory. 
- self._sampling_tensors = None - - # Initialize new sampling tensors - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) - - self._sampling_tensors = sampling_tensors - self._do_penalties = do_penalties - self._do_top_p_top_k = do_top_p_top_k - self._do_min_p = do_min_p - - def forward( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - """ - Single-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Pythonize sampling result & logprobs tensor - - Multi-step scheduling: - * Perform GPU-side sampling computation & compute - GPU-side logprobs tensor - * Defer Pythonization of sampling result & logprobs - tensor - * Encapsulate arguments required for deferred Pythonization - in the - [`SamplerOutput`][vllm.model_executor.layers.sampler.SamplerOutput] - structure - - Args: - logits: (num_tokens, vocab_size). - sampling_metadata: Metadata for sampling. - """ - assert logits is not None - _, vocab_size = logits.shape - - # Prepare sampling tensors with pinned memory to avoid blocking. - if not sampling_metadata.reuse_sampling_tensors: - self._init_sampling_tensors(logits, sampling_metadata) - elif self._do_penalties: - # In this case, the sampling tensors logic depends on - # "output_tokens" of a sequence. As a result, we cannot - # reuse sampling tensors, since "output_tokens" changes - # between decode runs. - self._init_sampling_tensors(logits, sampling_metadata) - - assert self._sampling_tensors is not None - sampling_tensors = self._sampling_tensors - do_penalties = self._do_penalties - do_top_p_top_k = self._do_top_p_top_k - do_min_p = self._do_min_p - - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - - # Apply presence and frequency penalties. - if do_penalties: - logits = apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - # Use float32 to apply temperature scaling. - # Use in-place division to avoid creating a new tensor. - logits = logits.to(torch.float) - logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) - - if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: - logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) - - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) - - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities. - logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) - - # Sample the next tokens. - maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=self.include_gpu_probs_tensor, - modify_greedy_probs=self._should_modify_greedy_probs_inplace, - ) - - if self.include_gpu_probs_tensor: - # Since we will defer sampler result Pythonization, - # preserve GPU-side tensors in support of later - # deferred pythonization of logprobs - assert maybe_sampled_tokens_tensor is not None - on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) - else: - # Since Pythonization has already happened, don't preserve - # GPU-side tensors. 
- on_device_tensors = None - - # Get the logprobs query results. - prompt_logprobs = None - sample_logprobs = None - if not sampling_metadata.skip_sampler_cpu_output: - # Pythonize logprobs now (GPU -> CPU); do not defer. - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - prompt_logprobs, sample_logprobs = get_logprobs( - logprobs, sampling_metadata, maybe_deferred_sample_results) - - return _build_sampler_output( - maybe_deferred_sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors, - skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) - - @property - def _should_modify_greedy_probs_inplace(self) -> bool: - """Whether or not the sampler should modify the probability distribution - of greedily-sampled tokens such that multinomial sampling would sample - the greedily-sampled token. - - In other words, if True then we set the probability of the greedily- - sampled token to 1. - - This is used by speculative decoding, which requires that the sampling - method be encoded into the probability distribution. - """ - return self.should_modify_greedy_probs_inplace - - -def _apply_min_tokens_penalty( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens - have not been generated yet - """ - # list of indices in logits that will be set to -inf - logits_to_penalize: list[tuple[int, int]] = [] - logits_applied = 0 - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - - sample_indices = seq_group.sample_indices - logits_applied += len(sample_indices) + len( - seq_group.prompt_logprob_indices) - if not seq_group.do_sample: - continue - - start_idx = sample_indices[0] - min_tokens = sampling_params.min_tokens - token_ids_to_penalize = sampling_params.all_stop_token_ids - if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: list[int] = [] - for j, seq_id in enumerate(seq_ids): - seq_data = seq_group.seq_data[seq_id] - if len(seq_data.output_token_ids_array) < min_tokens: - seqs_to_penalize.append(j) - - if seqs_to_penalize: - # convert to the index into logits - seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] - # itertools.product pairs each seq index with every token id - logits_to_penalize.extend( - itertools.product(seqs_to_penalize, token_ids_to_penalize)) - - if logits_to_penalize: - # use zip and * to group indices along each dimension - # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) - logits[tuple(zip(*logits_to_penalize))] = -float("inf") - - # verifies that no rows in logits were missed unexpectedly - assert logits_applied == logits.shape[0] - return logits - - -def _apply_top_k_top_p( - logits: torch.Tensor, - p: torch.Tensor, - k: torch.Tensor, -) -> torch.Tensor: - logits_sort, logits_idx = logits.sort(dim=-1, descending=False) - - # Apply top-k. - top_k_mask = logits_sort.size(1) - k.to(torch.long) - # Get all the top_k values. - top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) - top_k_mask = logits_sort < top_k_mask - logits_sort.masked_fill_(top_k_mask, -float("inf")) - - # Apply top-p. - probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) - # at least one - top_p_mask[:, -1] = False - logits_sort.masked_fill_(top_p_mask, -float("inf")) - - # Re-sort the probabilities. 
- logits = torch.empty_like(logits_sort).scatter_(dim=-1, - index=logits_idx, - src=logits_sort) - return logits - - -def _apply_min_p( - logits: torch.Tensor, - min_p: torch.Tensor, -) -> torch.Tensor: - """ - Adapted from - https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 - """ - probs = torch.softmax(logits, dim=-1) - top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs - tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill_(tokens_to_remove, -float("inf")) - - return logits - - -def _greedy_sample( - selected_seq_groups: list[SequenceGroupToSample], - samples: torch.Tensor, -) -> SampleResultType: - """Run greedy sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - samples: (num_selected_samples,) A tensor of samples. The length of - samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - samples_lst = samples.tolist() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - num_parent_seqs = len(seq_ids) - assert num_parent_seqs == 1, ( - "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples_lst[sample_idx]] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _random_sample( - selected_seq_groups: list[SequenceGroupToSample], - random_samples: torch.Tensor, -) -> SampleResultType: - """Run random sampling on a given samples. - - Args: - selected_seq_groups: A list of sequence groups batched. - random_samples: (num_selected_samples,) A tensor of samples. The - length of samples could be smaller than selected_seq_groups if - seq_group.do_sample is False. - Returns: - Tuple of (next_token_ids, parent_ids). The length of returned list is - same as the length of selected_seq_groups. If the corresponding - seq_group has do_sample=False, tuple contains ([], []) - """ - # Find the maximum n value of the prompt phase requests. - random_samples = random_samples.cpu() - sample_idx = 0 - results: SampleResultType = [] - for seq_group in selected_seq_groups: - if not seq_group.do_sample: - results.append(([], [])) - continue - - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - num_parent_seqs = len(seq_ids) - if is_prompt: - # Prompt phase. - parent_ids = [0] * sampling_params.n - next_token_ids = random_samples[ - sample_idx, :sampling_params.n].tolist() - else: - # Generation phase. - parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. 
-def _multinomial( - probs: torch.Tensor, - num_samples: int, - seq_groups: Optional[list[SequenceGroupToSample]] = None, -) -> torch.Tensor: - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - q = torch.empty_like(probs) - if seq_groups is None: - q.exponential_() - else: - sample_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - stride = len(seq_ids) * num_samples - assert seq_group.generator is not None - q[sample_idx:sample_idx + - stride].exponential_(generator=seq_group.generator) - sample_idx += stride - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - -def _top_k_top_p_multinomial_with_flashinfer( - probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): - if num_samples > 1: - probs = probs.repeat_interleave(num_samples, dim=0) - top_ks = top_ks.repeat_interleave(num_samples) - top_ps = top_ps.repeat_interleave(num_samples) - batch_next_token_ids = flashinfer_top_k_top_p_sampling( - probs, - top_ks, - top_ps, - ) - return batch_next_token_ids.view(-1, num_samples) - - -def get_pythonized_sample_results( - sample_result_args: SampleResultArgsType) -> SampleResultType: - '''This function consumes GPU-side sampler results and computes - Pythonized CPU-side sampler results (GPU -> CPU sync.) - - Single-step scheduling: this function is invoked at sampling-time - for immediate Pythonization. - - Multi-step scheduling: Pythonization is deferred until after multiple - GPU-side steps have been completed. - - Args: - sample_result_args: GPU-side inputs to the Pythonization process - - Returns: - Pythonized sampler results - ''' - - ( - sample_metadata, - sampling_metadata, - greedy_samples, - multinomial_samples, - sample_results_dict, - ) = ( - sample_result_args.sample_metadata, - sample_result_args.sampling_metadata, - sample_result_args.greedy_samples, - sample_result_args.multinomial_samples, - sample_result_args.sample_results_dict, - ) - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - return [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - - -def _sample_with_torch( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - '''Torch-oriented _sample() implementation. 
- - Single-step scheduling: - * Perform GPU-side sampling computation - * Immediately Pythonize sampling result - - Multi-step scheduling: - * Perform GPU-side sampling computation - * Defer Pythonization & preserve GPU-side - tensors required for Pythonization - ''' - - categorized_seq_group_ids: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: SampleResultsDictType = {} - sample_metadata: SampleMetadataType = {} - multinomial_samples: MultinomialSamplesType = {} - greedy_samples: Optional[torch.Tensor] = None - - # Create output tensor for sampled token ids. - if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), - VLLM_INVALID_TOKEN_ID, - dtype=torch.long, - device=logprobs.device) - else: - sampled_token_ids_tensor = None - - # Counterintuitively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups) - long_sample_indices = sample_indices.long() - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[long_sample_indices], - dim=-1) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[ - long_sample_indices] = greedy_samples.unsqueeze(-1) - - if modify_greedy_probs: - # If required, modify the probabilities such that sampling from - # the modified distribution would always sample the argmax - # token id. - _modify_greedy_probs_inplace(logprobs, probs, - long_sample_indices, - greedy_samples) - - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_n_in_batch = 1 - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_n_in_batch = max(max_n_in_batch, sampling_params.n) - seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else - seq_groups) - - if flashinfer_top_k_top_p_sampling is not None: - logger.warning("FlashInfer 0.2.3+ does not support " - "per-request generators. Falling back to " - "PyTorch-native implementation.") - - multinomial_samples[sampling_type] = _multinomial( - probs[long_sample_indices], - max_n_in_batch, - seq_groups=seq_groups_arg) - - if sampled_token_ids_tensor is not None: - # Store sampled tokens in output tensor. - sampled_token_ids_tensor[long_sample_indices] = \ - multinomial_samples[sampling_type].to(torch.long) - - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # Encapsulate arguments for computing Pythonized sampler - # results, whether deferred or otherwise. - maybe_deferred_args = SampleResultArgsType( - sampling_metadata=sampling_metadata, - sample_metadata=sample_metadata, - multinomial_samples=multinomial_samples, - greedy_samples=greedy_samples, - sample_results_dict=sample_results_dict) - - if not sampling_metadata.skip_sampler_cpu_output: - # GPU<->CPU sync happens here. 
- # This also converts the sampler output to a Python object. - # Return Pythonized sampler result & sampled token ids - return get_pythonized_sample_results( - maybe_deferred_args), sampled_token_ids_tensor - else: - # Defer sampler result Pythonization; return deferred - # Pythonization args & sampled token ids - return ( - maybe_deferred_args, - sampled_token_ids_tensor, - ) - - -def _sample( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, - include_gpu_probs_tensor: bool, - modify_greedy_probs: bool, -) -> SampleReturnType: - """ - Args: - probs: (num_query_tokens_in_batch, num_vocab) - logprobs: (num_query_tokens_in_batch, num_vocab) - sampling_metadata: The metadata for a batch for sampling. - sampling_tensors: Tensors that include sampling related metadata. - - Returns: - (next_token_ids, parent_seq_ids) for each seq group in a batch. - If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. - """ - return _sample_with_torch( - probs, - logprobs, - sampling_metadata, - sampling_tensors, - include_gpu_probs_tensor=include_gpu_probs_tensor, - modify_greedy_probs=modify_greedy_probs, - ) - - -def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - """ - This function calculates the ranks of the chosen tokens in a logprob tensor. - - Args: - x (torch.Tensor): 2D logprob tensor of shape (N, M) - where N is the no. of tokens and M is the vocab dim. - indices (torch.Tensor): List of chosen token indices. - - Returns: - torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank - of the chosen token in the input logprob tensor. - """ - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) - - -def get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: SampleResultType, -) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: - """Return sample logprobs and prompt logprobs. - - The logic consists of 3 parts. - - Select indices to compute logprob from, ranks of token ids, and - the top k token ids from logprobs. - - Compute prompt logprobs if required. - - Compute sample logprobs if required. - - Args: - logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's - logprob per vocab. Sequence groups' query tokens are batched in a - single flattened tensor. For example, assuming there are N - seq groups, it is sorted by prefill tokens for seq_group_1 (if - prompt logprob is enabled), decode tokens for seq_group_1 (if - sampling is required), prefill tokens for seq_group_2, ... - sampling_metadata: The sampling metadata. - sample_results: (num_seq_groups) The tuple of (next_token_ids, - parent_ids) for each sequence group. When beam search is enabled, - sample_results can contain different number of seq_ids from - sampling_metadata.seq_groups. It is because beam search creates - 2 * BEAM_WIDTH number of samples (whereas there are only up to - BEAM_WIDTH number of seq_ids). - - Returns: - A tuple of prompt and sample logprobs per sequence group in a batch. - """ - # The index of query token to calculate logprobs. It includes both - # prompt and sample logprob indices. - query_indices: list[int] = [] - # The next token ids to get the logprob value from. 
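# Editorial aside (not part of the patch): a tiny worked example of the rank
# computation performed by the removed `_get_ranks` helper above. The rank of
# a chosen token is 1 + the number of vocabulary entries with a strictly
# larger logprob, so the argmax token always gets rank 1. The values below
# are made up for illustration.
import torch

def get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    vals = x[torch.arange(len(x), device=x.device), indices]
    return (x > vals[:, None]).sum(dim=1).add_(1)

logprobs = torch.tensor([[-0.1, -2.0, -3.0],
                         [-1.5, -0.2, -0.9]])
chosen = torch.tensor([2, 1])
print(get_ranks(logprobs, chosen))  # tensor([3, 1]): -3.0 is 3rd best, -0.2 is best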
- next_token_ids: list[int] = [] - # The largest requested number of logprobs. We find logprobs as many as the - # largest num logprobs in this API. If every logprobs is None, it will be - # set to -1. - largest_num_logprobs = -1 - - # Select indices to compute logprob from, ranks of token ids, and the top - # k token ids from logprobs. - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - sample_results): - sampling_params = seq_group.sampling_params - - # Update indices and tokens for prompt logprobs. - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - query_indices.extend(seq_group.prompt_logprob_indices) - next_token_ids.extend(next_prompt_tokens) - - # Update indices and next tokenes for sample logprob. - if seq_group.do_sample: - token_ids, parent_seq_ids = sample_result - # NOTE: We cannot directly use sample_indices because - # sample_indices only contain parent seq_ids of a previous step. - # The current step may have different number of seq_ids, and - # we can obtain it from `sample_result[1]`. - query_idx = seq_group.sample_indices[0] - query_indices.extend( - [query_idx + parent_id for parent_id in parent_seq_ids]) - next_token_ids.extend(token_ids) - - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - - assert len(next_token_ids) == len(query_indices) - - if len(query_indices) == 0: - empty_sampled_logprob: SampleLogprobs = [] - empty_prompt_logprob: Optional[PromptLogprobs] = None - num_seq_groups = len(sampling_metadata.seq_groups) - return [empty_prompt_logprob - ] * num_seq_groups, [empty_sampled_logprob] * num_seq_groups - - selected_logprobs, ranks = None, None - top_logprobs, top_token_ids = None, None - - # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can - # skip the whole logprob calculation. - if largest_num_logprobs >= 0: - query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) - next_token_ids_gpu = torch.tensor(next_token_ids, - device=logprobs.device) - - # (num_selected_query_tokens, num_logprobs). Note that query_indices can - # contain duplicates if beam search is enabled. - selected_logprobs = logprobs[[ - query_indices_gpu, - next_token_ids_gpu, - ]] - ranks = _get_ranks( - logprobs[query_indices_gpu], - next_token_ids_gpu, - ) - assert selected_logprobs.shape[0] == ranks.shape[0] - - # We need to compute top k only if there exists logprobs > 0. - if largest_num_logprobs > 0: - # Logprobs of topk tokens for a batch of sequence groups. - # (num_query_tokens_across_batch). - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - top_logprobs = top_logprobs.to('cpu') - top_token_ids = top_token_ids.to('cpu') - - selected_logprobs = selected_logprobs.to('cpu') - ranks = ranks.to('cpu') - - # Find prompt/sample logprobs. 
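# Editorial aside (not part of the patch): sketch of the gather pattern used
# in the logprob path above -- indexing a (num_query_tokens, vocab) logprob
# matrix with a pair of index tensors picks one (row, column) entry per query
# token, and a single topk plus one bulk .cpu() copy at the end replaces many
# per-element .item() syncs. Shapes and names are illustrative only.
import torch

logprobs = torch.randn(5, 32000).log_softmax(dim=-1)   # (num_query_tokens, vocab)
query_indices = torch.tensor([0, 1, 2, 3, 4])
next_token_ids = torch.randint(0, 32000, (5,))

selected = logprobs[query_indices, next_token_ids]      # (5,) chosen-token logprobs
top_vals, top_ids = torch.topk(logprobs, k=5, dim=-1)   # (5, 5) top-k per token

# One bulk device->host transfer, then cheap Python-side .tolist() calls.
selected, top_vals, top_ids = selected.cpu(), top_vals.cpu(), top_ids.cpu()
print(selected.tolist()[:2], top_ids[0].tolist())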
- prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: list[SampleLogprobs] = [] - top_logprob_idx = 0 - selected_logprobs_idx = 0 - - for seq_group, sample_result in zip(sampling_metadata.seq_groups, - sample_results): - (prompt_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_prompt_logprob_if_needed( - seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, - selected_logprobs_idx, top_logprob_idx) - prompt_logprobs_per_seq_group.append(prompt_logprobs) - - (sampled_logprobs, top_logprob_idx, - selected_logprobs_idx) = _get_sampled_logprob_if_needed( - seq_group, sample_result, selected_logprobs, ranks, top_token_ids, - top_logprobs, selected_logprobs_idx, top_logprob_idx) - sample_logprobs_per_seq_group.append(sampled_logprobs) - - return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group - - -def _get_prompt_logprob_if_needed( - seq_group: SequenceGroupToSample, - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the prompt logprob from a sequence group if needed.""" - sampling_params = seq_group.sampling_params - is_prompt = seq_group.is_prompt - - # Find prompt logprobs - prompt_logprobs: Optional[PromptLogprobs] = None - if is_prompt and sampling_params.prompt_logprobs is not None: - prompt_logprobs = [] - num_logprobs = sampling_params.prompt_logprobs - next_prompt_tokens = _get_next_prompt_tokens(seq_group) - # Pre-select indexes and create a list. It is faster than calling .item - # repetitively. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_prompt_tokens)].tolist() - - for idx, token_id in enumerate(next_prompt_tokens): - # Calculate the prompt logprob of the real prompt tokens. - # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: dict[int, tuple[float, int]] = { - token_id: (selected_logprob_items[idx], rank_items[idx]) - } - - # Add top K prompt logprobs along with its rank. - if num_logprobs > 0: - top_ids = top_token_ids[ - top_logprob_idx, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - prompt_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip(top_ids, top_probs, - top_ranks) - }) - prompt_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in prompt_logprobs_dict.items() - }) - # + 1 to go to the next prompt token. - top_logprob_idx += 1 - - # + len(next_prompt_tokens) to go to the next prompt. 
- selected_logprobs_idx += len(next_prompt_tokens) - return prompt_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _get_sampled_logprob_if_needed( - seq_group: SequenceGroupToSample, - sample_result: tuple[list[int], list[int]], - selected_logprobs: torch.Tensor, - ranks: torch.Tensor, - top_token_ids: torch.Tensor, - top_logprobs: torch.Tensor, - selected_logprobs_idx: int, - top_logprob_idx: int, -): - """Compute the sample logprob if needed.""" - seq_ids = seq_group.seq_ids - num_logprobs = seq_group.sampling_params.logprobs - sampled_logprobs: SampleLogprobs = [] - next_token_ids, parent_seq_ids = sample_result - - if seq_group.do_sample: - assert len(next_token_ids) > 0 - if num_logprobs is None: - for next_token_id in next_token_ids: - # Use a dummy logprob - sampled_logprobs.append({next_token_id: Logprob(inf)}) - else: - # Pre-select items from tensor. tolist() is faster than repetitive - # `.item()` calls. - selected_logprob_items = selected_logprobs[ - selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + - len(next_token_ids)].tolist() - for idx, (next_token_id, parent_id) in enumerate( - zip(next_token_ids, parent_seq_ids)): - # Get the logprob of a sampled token. - sampled_logprobs_dict = { - next_token_id: - (selected_logprob_items[idx], rank_items[idx]) - } - if num_logprobs is not None and num_logprobs > 0: - # Get top K logprobs. - top_ids = top_token_ids[top_logprob_idx + - parent_id, :num_logprobs].tolist() - top_probs = top_logprobs[ - top_logprob_idx + parent_id, :num_logprobs].tolist() - # Top K is already sorted by rank, so we can use 1 ~ - # num_logprobs + 1 for rank. - top_ranks = range(1, num_logprobs + 1) - sampled_logprobs_dict.update({ - top_id: (top_prob, rank) - for top_id, top_prob, rank in zip( - top_ids, top_probs, top_ranks) - }) - - sampled_logprobs.append({ - token_id: Logprob(*logprob_and_rank) - for token_id, logprob_and_rank in - sampled_logprobs_dict.items() - }) - - # NOTE: This part of code is not intuitive. `selected_logprobs` include - # logprobs for the current step, which has len(next_token_ids) tokens - # per sequence group. `logprobs` includes logprobs from the previous - # steps, which has len(seq_ids) tokens per sequence group. - - # Iterate to the next sequence group in a batch. - selected_logprobs_idx += len(next_token_ids) - # Iterate to the next sequence group in a batch. - top_logprob_idx += len(seq_ids) - return sampled_logprobs, top_logprob_idx, selected_logprobs_idx - - -def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, - sample_indices: torch.Tensor, - greedy_samples: torch.Tensor) -> None: - """Modify the probability distributions of the greedily-sampled tokens such - that each sampled token has a "probability" of 1.0. This is required by - speculative decoding, which depends on the sampling method being encoded - within the probability distribution for correctness. - - # Why do we only need to do this for greedy sampling? - - vLLM's sampler performs the following steps for greedy or multinomial - (random) sampling: - 1. Get logits from model. - 2. Modify logits according to per-sequence sampling parameters. - - Multiply by temperature, top-k and top-p masking, penalize tokens - according to their frequency, etc. - 3. Sample a token. - - Random sampling simply samples from the modified probability - distribution. - - Greedy sampling performs `argmax` to obtain the token with the - highest likelihood. 
- - Ignoring greedy sampling for a moment, we find that the computed probability - distribution has the following property: we can sample from it independently - and find that the token sampled by the Sampler has a frequency corresponding - to how often we see it in our sampling. In other words, for tokens sampled - with vLLM's random SamplingType, the computed probability distribution - encodes the sampling methodology completely. - - Greedy sampling does not normally have this property. vLLM modifies logits - according to sampling params, then performs `argmax`, then returns the - sampled token and the computed probability distribution. If we sample from - the distribution, we'll find the likelihood of the greedily-sampled token - is not always 1.0. - - Since lossless speculative decoding requires that the sampling methodology - be encoded within the probability distribution, we are motivated to modify - the probability distribution such that the sampled token has probability 1 - when speculative decoding is used. - - NOTE: Alternatively, we could use an extremely low temperature to achieve - greedy sampling using multinomial computation and unite the codepaths. This - has implications on the overall design of the sampler, e.g. how to record - accurate logprobs for the user, so this improvement is deferred to later. - """ - # NOTE: logprobs are not modified so they can be returned to the user. - probs[sample_indices, :] = 0 - probs[sample_indices, greedy_samples] = 1.0 - - -def _build_sampler_output( - maybe_deferred_sample_results: MaybeDeferredSampleResultType, - sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], - sample_logprobs: Optional[list[SampleLogprobs]], - on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, - torch.Tensor]], - skip_sampler_cpu_output: bool = False, -) -> SamplerOutput: - """Construct Python objects with the output of sampling. - - Args: - on_device_tensors: Tuple containing on-device tensors with the - probabilities used in sampling and the sampled token ids. This - allows post-processing without copies to CPU/serialization, e.g. in - speculative decoding rejection sampling. - """ - sampler_output: list[CompletionSequenceGroupOutput] = [] - - if skip_sampler_cpu_output: - assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) - deferred_sample_results_args = maybe_deferred_sample_results - else: - assert prompt_logprobs is not None - assert sample_logprobs is not None - assert not isinstance(maybe_deferred_sample_results, - SampleResultArgsType) - assert len(sampling_metadata.seq_groups) \ - == len(maybe_deferred_sample_results) \ - == len(prompt_logprobs) \ - == len(sample_logprobs) - deferred_sample_results_args = None - - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - maybe_deferred_sample_results, - prompt_logprobs, sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: list[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip( - parent_ids, next_token_ids, group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, - group_prompt_logprobs)) - - # If not specified, store None values in SamplerOutput. 
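# Editorial aside (not part of the patch): minimal sketch of the in-place
# "greedy probs" rewrite described in the docstring above. After zeroing the
# selected rows and writing 1.0 at the argmax position, sampling from the
# distribution can only return the greedily chosen token, which is what the
# rejection-sampling step of speculative decoding relies on. Row and column
# values below are illustrative only.
import torch

probs = torch.softmax(torch.randn(4, 8), dim=-1)
sample_indices = torch.tensor([0, 2])                  # rows sampled greedily
greedy_samples = probs[sample_indices].argmax(dim=-1)  # their argmax tokens

probs[sample_indices, :] = 0.0
probs[sample_indices, greedy_samples] = 1.0

# Sampling from the modified rows is now deterministic.
assert torch.equal(torch.multinomial(probs[sample_indices], 1).squeeze(-1),
                   greedy_samples)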
- if on_device_tensors is not None: - (sampled_token_probs, logprobs_tensor, - sampled_token_ids) = on_device_tensors - else: - sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, - None) - - return SamplerOutput( - outputs=sampler_output, - sampled_token_probs=sampled_token_probs, - sampled_token_ids=sampled_token_ids, - logprobs=logprobs_tensor, - deferred_sample_results_args=deferred_sample_results_args) - - -def _get_next_prompt_tokens( - seq_group: SequenceGroupToSample) -> tuple[int, ...]: - """Get a list of next prompt tokens to compute logprob from a - given sequence group. - - It is used to compute prompt logprob. Imagine you have logprob for each - query token. Query token needs to know the next prompt token id to compute - prompt logprob. This is a helper to obtain next prompt token ids. - - This API has to be used only when the caller knows seq_group is in prefill - stage. - - Returns: - A list of next prompt tokens to compute logprob. - """ - assert seq_group.is_prompt, ( - "Caller should ensure the sequence group is in a prefill stage.") - seq_ids = seq_group.seq_ids - query_len = seq_group.query_len - assert query_len is not None - # prompt has only 1 seq id. - assert len(seq_ids) == 1 - seq_data = seq_group.seq_data[seq_ids[0]] - computed_len = seq_data.get_num_computed_tokens() - prompt_tokens = seq_data.prompt_token_ids - # +1 because we are looking for a next prompt token. - next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + query_len + 1, - len(prompt_tokens)) - next_prompt_tokens = prompt_tokens[ - next_token_index_start:next_token_index_end] - return next_prompt_tokens diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 6ba8ad372c95..b0a96fca2ff8 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -2,18 +2,15 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Optional import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from .utils import maybe_prefix @@ -105,8 +102,10 @@ def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]: return [block(hidden_states) for block in self.blocks] def compute_logits( - self, hidden_states: list[torch.Tensor], - sampling_metadata: SamplingMetadata) -> list[torch.Tensor]: + self, + hidden_states: list[torch.Tensor], + sampling_metadata, + ) -> list[torch.Tensor]: logits_lst: list[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): @@ -130,57 +129,6 @@ def compute_logits( return logits_lst - def sample( - self, - logits: list[torch.Tensor], - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - logits = torch.stack(logits, dim=0).float() - logprobs = torch.log_softmax(logits, dim=-1) - token_ids = logits.argmax(-1) # support only top-1 for now - probs = torch.softmax(logits, dim=-1) - - token_id_list = [] - token_prob_list = [] - token_logprob_list = [] - - for idx, seq_group in enumerate(sampling_metadata.seq_groups): - token_id_list.append(token_ids[:, 
seq_group.sample_indices]) - token_prob_list.append(probs[:, seq_group.sample_indices]) - token_logprob_list.append(logprobs[:, seq_group.sample_indices]) - - outputs: list[Optional[SamplerOutput]] = [] - for idx in range(len(sampling_metadata.seq_groups)): - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_prob_list[idx].squeeze(1), - logprobs=token_logprob_list[idx].squeeze(1), - sampled_token_ids=token_id_list[idx].squeeze(1), - )) - - return outputs - - def generate_proposals( - self, - previous_hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[list[SamplerOutput]]: - # During preemption, we may receive an empty tensor (batch_size=0) - if previous_hidden_states.size(0) == 0: - # Return None to signal the Top1Proposer that no proposals - # were generated for this batch, allowing it to handle this - # special case appropriately - return None - - return self.sample( - logits=self.compute_logits( - hidden_states=self.forward(previous_hidden_states), - sampling_metadata=sampling_metadata, - ), - sampling_metadata=sampling_metadata, - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index c6a97388dc18..d057eb49a62d 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -8,9 +8,7 @@ import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -141,55 +139,57 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.config = config self.logits_processor = LogitsProcessor(config.vocab_size, config.vocab_size, 1.0) - self.sampler = get_sampler() - def generate_proposals( - self, - input_ids: torch.Tensor, - previous_hidden_states: torch.Tensor, - num_predict_tokens: int, - sampling_metadata: SamplingMetadata, - ) -> list[SamplerOutput]: - if num_predict_tokens > self.max_speculative_tokens: - raise ValueError(f"Max speculative tokens for model is " - f"{self.max_speculative_tokens}, but " - f"{num_predict_tokens} were requested") - - # b x 1 x d - previous_hidden_states = previous_hidden_states.unsqueeze(1) + # NOTE(woosuk): This method is commented out because it is old code + # using V0. We should either port it to V1 or remove it. 
- if self.scale_input: - previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 + # def generate_proposals( + # self, + # input_ids: torch.Tensor, + # previous_hidden_states: torch.Tensor, + # num_predict_tokens: int, + # sampling_metadata: SamplingMetadata, + # ) -> list[SamplerOutput]: + # if num_predict_tokens > self.max_speculative_tokens: + # raise ValueError(f"Max speculative tokens for model is " + # f"{self.max_speculative_tokens}, but " + # f"{num_predict_tokens} were requested") + + # # b x 1 x d + # previous_hidden_states = previous_hidden_states.unsqueeze(1) + + # if self.scale_input: + # previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 - # b x 1 - last_tokens = input_ids.unsqueeze(1) + # # b x 1 + # last_tokens = input_ids.unsqueeze(1) - next_tokens = [] + # next_tokens = [] - for head_index in range(num_predict_tokens): + # for head_index in range(num_predict_tokens): - # Project and predict - z = self.emb[head_index](last_tokens) # b k d - states = self.proj[head_index](previous_hidden_states) + # # Project and predict + # z = self.emb[head_index](last_tokens) # b k d + # states = self.proj[head_index](previous_hidden_states) - # Weighted add of state_weight*state and emb_weight*z - # Let subsequent LN take care of denominator - # state_weight is close to 1, so shouldn't be any precision issues - states.add_(z, alpha=self.emb_weight / self.state_weight) + # # Weighted add of state_weight*state and emb_weight*z + # # Let subsequent LN take care of denominator + # # state_weight is close to 1, so shouldn't be any precision issues + # states.add_(z, alpha=self.emb_weight / self.state_weight) - states = self.activation(self.ln[head_index](states)) # b k d - previous_hidden_states = states - # TODO: not yet supporting top_k_tokens_per_head - states = states.flatten(0, 1) + # states = self.activation(self.ln[head_index](states)) # b k d + # previous_hidden_states = states + # # TODO: not yet supporting top_k_tokens_per_head + # states = states.flatten(0, 1) - logits = self.logits_processor(self.head[head_index], states, - sampling_metadata) + # logits = self.logits_processor(self.head[head_index], states, + # sampling_metadata) - output = self.sampler(logits, sampling_metadata) - last_tokens = output.sampled_token_ids - next_tokens.append(output) + # output = self.sampler(logits, sampling_metadata) + # last_tokens = output.sampled_token_ids + # next_tokens.append(output) - return next_tokens + # return next_tokens def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index c4548ee168bd..aa7c434a44ae 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -697,16 +697,12 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - # If the shape is the same, it means that we have already - # prune hidden states manually. 
- prune_hidden_states = hidden_states.size( - 0) != sampling_metadata.selected_token_indices.size(0) processed_logits = self.logits_processor( self.lm_head, hidden_states, sampling_metadata, self.embedding_bias, - prune_hidden_states=prune_hidden_states) + ) return processed_logits def load_weights( diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 2315f9dad5a5..8c4548ff7f7d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,597 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from array import array -from dataclasses import dataclass -from typing import Optional - -import torch - -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, - SequenceGroupMetadata) -from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad) - -_SAMPLING_EPS = 1e-5 - - -@dataclass -class SequenceGroupToSample: - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Sequence ids for the sequence group in a previous step. - seq_ids: list[int] - sampling_params: SamplingParams - # seq_id -> sequence data. - seq_data: dict[int, SequenceData] - # The length of the sequence (all tokens seen in the past + new token to - # compute attention) of the sequence group. None if it is in a decode - # stage. - seq_len: Optional[int] - # The length of new query tokens to compute in the current step. None if it - # is in a decode stage. The length of query_len <= seq_len if chunked - # prefill is enabled. - query_len: Optional[int] - # A random number generator for sampling. - generator: Optional[torch.Generator] - # True if the sequence group is in prefill stage. False if it is in a - # decode stage. - is_prompt: bool - # Query token indices from logits. to compute prompt logprob. Empty if - # prompt logprob is not required. - prompt_logprob_indices: list[int] - # Sample token indices from logits. Empty if sampling is not required. 
- sample_indices: list[int] - - @property - def do_sample(self): - return len(self.sample_indices) > 0 - - def __post_init__(self): - if len(self.prompt_logprob_indices) > 0: - assert self.sampling_params.prompt_logprobs is not None - if self.is_prompt: - assert self.seq_len is not None - assert self.query_len is not None - - -def gen_seq_group_to_sample_builder(num_seqs: int): - return lambda: SequenceGroupToSample( - seq_ids=[0] * num_seqs, - sampling_params=None, - seq_data=None, # type: ignore - seq_len=0, - query_len=0, - generator=None, - is_prompt=True, - prompt_logprob_indices=[], - sample_indices=[], - ) - - -class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations""" - - def __init__(self): - self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {} - - def get_cached_seq_group_to_sample(self, num_seqs): - if num_seqs not in self._seq_group_to_sample_cache: - self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( - gen_seq_group_to_sample_builder(num_seqs)) - - obj = self._seq_group_to_sample_cache[num_seqs].get_object() - return obj - - def reset(self): - for cache in self._seq_group_to_sample_cache.values(): - cache.reset() - class SamplingMetadata: - """Metadata for input sequences. Used in sampler. - - The usage is as follows; - ``` - hidden_states = execute_model(...) - logits = hidden_states[sampling_metadata.selected_token_indices] - sample(logits) - - def sample(logits): - # Use categorized_sample_indices for sampling.... - ``` - - Args: - seq_groups: List of batched sequence groups. - selected_token_indices: (num_query_tokens_to_logprob). Indices to find - logits from the initial model output hidden states. - categorized_sample_indices: SamplingType -> token indices to sample. - Each token indices is 2D tensor of (num_indices, num_indices) where - the first item means the sample index within the returned logit - (before pruning padding), and the second item means the sample - index after pruning using selected_token_indices. - For example, if the returned logit is [1, 2, 3], and we select - [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, - The first tuple is [1, 2] (sampled index within original logit), - and the second tuple is [0, 1] (sampled index within pruned logit). - num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU - serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling - tensors that are part of the sampler forward pass. Currently, - it is mainly used for multi-step decode. 
- - """ - - def __init__( - self, - seq_groups: list[SequenceGroupToSample], - selected_token_indices: torch.Tensor, - categorized_sample_indices: dict[SamplingType, torch.Tensor], - num_prompts: int, - skip_sampler_cpu_output: bool = False, - reuse_sampling_tensors: bool = False, - ) -> None: - self.seq_groups = seq_groups - self.selected_token_indices = selected_token_indices - self.categorized_sample_indices = categorized_sample_indices - self.num_prompts = num_prompts - self.skip_sampler_cpu_output = skip_sampler_cpu_output - self.reuse_sampling_tensors = reuse_sampling_tensors - - @staticmethod - def prepare( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - pin_memory: bool, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, - ) -> "SamplingMetadata": - ( - seq_groups, - selected_token_indices, - categorized_sample_indices, - num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators, cache) - selected_token_indices = async_tensor_h2d( - selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory, - ) - categorized_sample_indices = { - t: - async_tensor_h2d( - seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory, - ) - for t, seq_ids in categorized_sample_indices.items() - } - - sampling_metadata = SamplingMetadata( - seq_groups=seq_groups, - selected_token_indices=selected_token_indices, - categorized_sample_indices=categorized_sample_indices, - num_prompts=num_prompts, - ) - return sampling_metadata - - def __repr__(self) -> str: - return ( - "SamplingMetadata(" - f"seq_groups={self.seq_groups}, " - f"selected_token_indices={self.selected_token_indices}, " - f"categorized_sample_indices={self.categorized_sample_indices})") - - -def _prepare_seq_groups( - seq_group_metadata_list: list[SequenceGroupMetadata], - seq_lens: list[int], - query_lens: list[int], - device: str, - generators: Optional[dict[str, torch.Generator]] = None, - cache: Optional[SamplingMetadataCache] = None, -) -> tuple[ - list[SequenceGroupToSample], - list[int], - dict[SamplingType, list[int]], - int, -]: - """Prepare sequence groups and indices for sampling. - - Args: - seq_group_metadata_list: A list of sequence group to batch. - seq_lens: A list of sequence lens per sequence group. - Index of prompt len should match with seq_group_metadata_list. - query_lens: A list of query lengths. Prompt lens include the length - of entire prompt tokens, and it could be shorter. - device: A device to use for random number generators, - `SequenceGroupToSample.generator`. - generators: A store of per-request random number generators used - for seeded requests. - - Returns: - seq_groups: A list of sequence group to sample. - selected_token_indices: See the definition from `SamplingMetadata`. - categorized_sample_indices: See the definition from `SamplingMetadata`. - num_prompts: Total number of prompts from `seq_group_metadata_list`. - """ - # Batched sequence groups for the current model forward stsep. - seq_groups: list[SequenceGroupToSample] = [] - # A list of token indices to sample/compute logprob. It is used to - # prune the outcome logits from the model for the performance. - selected_token_indices: list[int] = [] - # Used for selected_token_indices. 
- model_output_idx = 0 - - # Sampling type -> ( - # indices to sample/prompt logprob within pruned output logits, - # indices to sample within pruned logits) - categorized_sample_indices: dict[SamplingType, list[int]] = { - t: [] - for t in SamplingType - } - # Index of logits to compute logprob. Logits include both prompt logprob - # and sample logprob indices. - logit_idx = 0 - # Total number of prompts from given sequence groups. - num_prompts = 0 - - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = seq_group_metadata.seq_data.keys() - - if cache is not None: - sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) - - for j, seq_id in enumerate(seq_ids): - sample_obj.seq_ids[j] = seq_id - - sample_obj.prompt_logprob_indices.clear() - sample_obj.sample_indices.clear() - - sampling_params = seq_group_metadata.sampling_params - is_prompt = seq_group_metadata.is_prompt - generator: Optional[torch.Generator] = None - # If the current seq group is in decode stage, it is None. - seq_len: Optional[int] = None - query_len: Optional[int] = None - prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices - if cache is not None else []) - sample_indices: list[int] = (sample_obj.sample_indices - if cache is not None else []) - do_sample = seq_group_metadata.do_sample - - if seq_group_metadata.is_prompt: - if sampling_params.seed is not None: - generator = torch.Generator(device=device).manual_seed( - sampling_params.seed) - if generators is not None: - generators[seq_group_metadata.request_id] = generator - - num_prompts += 1 - num_prefill_sample = len(seq_ids) - assert num_prefill_sample == 1 - assert query_lens is not None and seq_lens is not None - query_len, seq_len = query_lens[i], seq_lens[i] - # If we need sampling, exclude num_prefill_sample tokens from - # prompt logprob. - prompt_logprob_len = (query_len - num_prefill_sample - if do_sample else query_len) - sample_len = num_prefill_sample if do_sample else 0 - else: - # Decode - prompt_logprob_len = 0 - query_len = query_lens[i] if query_lens is not None and len( - query_lens) > 0 else 1 - sample_len = len(seq_ids) * query_len if do_sample else 0 - - if sampling_params.seed is not None and generators is not None: - generator = generators.get(seq_group_metadata.request_id) - - # Update indices to select from the model output. - """ - This blocks computes selected_token_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - """ - - if sampling_params.prompt_logprobs is not None: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + prompt_logprob_len)) - model_output_idx += prompt_logprob_len - if do_sample: - selected_token_indices.extend( - range(model_output_idx, model_output_idx + sample_len)) - model_output_idx += sample_len - - # We now find indices for logprob computation and sampling. - """ - This block computes categorized_sample_indices which is used in the - following way. - - hidden_states = model(...) - logits = hidden_states[selected_token_indices] - def sample(logits): - # Use categorized_sample_indices for sampling. - # prompt_logprob_indices to find prompt logprob indices. - # sample_indices to find sample indices. 
- """ - - if sampling_params.prompt_logprobs is not None: - prompt_logprob_indices.extend( - range(logit_idx, logit_idx + prompt_logprob_len)) - logit_idx += prompt_logprob_len - if do_sample: - sample_indices.extend(range(logit_idx, logit_idx + sample_len)) - categorized_sample_indices[sampling_params.sampling_type].extend( - list(range(logit_idx, logit_idx + sample_len))) - logit_idx += sample_len - - if cache is not None: - sample_obj.sampling_params = sampling_params - sample_obj.seq_data = seq_group_metadata.seq_data - sample_obj.seq_len = seq_len - sample_obj.query_len = query_len - sample_obj.generator = generator - sample_obj.is_prompt = is_prompt - else: - sample_obj = SequenceGroupToSample( - seq_ids=list(seq_ids), - sampling_params=sampling_params, - seq_data=seq_group_metadata.seq_data, - seq_len=seq_len, - query_len=query_len, - generator=generator, - is_prompt=is_prompt, - prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices), - ) - - seq_groups.append(sample_obj) - - if cache is not None: - cache.reset() - - return (seq_groups, selected_token_indices, categorized_sample_indices, - num_prompts) - - -@dataclass -class SamplingTensors: - """Tensors for sampling.""" - - temperatures: torch.Tensor - top_ps: torch.Tensor - top_ks: torch.Tensor - min_ps: torch.Tensor - presence_penalties: torch.Tensor - frequency_penalties: torch.Tensor - repetition_penalties: torch.Tensor - prompt_tokens: torch.Tensor - output_tokens: torch.Tensor - - @classmethod - def from_sampling_metadata( - cls, - sampling_metadata: "SamplingMetadata", - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> tuple["SamplingTensors", bool, bool, bool]: - prompt_tokens: list[array] = [] - output_tokens: list[array] = [] - top_ks: list[int] = [] - temperatures: list[float] = [] - top_ps: list[float] = [] - min_ps: list[float] = [] - presence_penalties: list[float] = [] - frequency_penalties: list[float] = [] - repetition_penalties: list[float] = [] - do_penalties = False - do_top_p_top_k = False - do_min_p = False - - assert sampling_metadata.seq_groups is not None - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - temperature = sampling_params.temperature - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - top_p = sampling_params.top_p - min_p = sampling_params.min_p - - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k < 1 else top_k - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. 
- temperature = 1.0 - if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS - or top_k != vocab_size): - do_top_p_top_k = True - if not do_min_p and min_p > _SAMPLING_EPS: - do_min_p = True - if not do_penalties and (abs(p) >= _SAMPLING_EPS - or abs(f) >= _SAMPLING_EPS - or abs(r - 1.0) >= _SAMPLING_EPS): - do_penalties = True - - is_prompt = seq_group.is_prompt - if is_prompt and sampling_params.prompt_logprobs is not None: - # For tokens in the prompt that we only need to get - # their logprobs - query_len = seq_group.query_len - assert query_len is not None - prefill_len = len(seq_group.prompt_logprob_indices) - temperatures += [temperature] * prefill_len - top_ps += [top_p] * prefill_len - top_ks += [top_k] * prefill_len - min_ps += [min_p] * prefill_len - presence_penalties += [0] * prefill_len - frequency_penalties += [0] * prefill_len - repetition_penalties += [1] * prefill_len - - if seq_group.do_sample: - sample_lens = len(seq_group.sample_indices) - assert sample_lens >= len(seq_ids) - temperatures += [temperature] * sample_lens - top_ps += [top_p] * sample_lens - top_ks += [top_k] * sample_lens - min_ps += [min_p] * sample_lens - presence_penalties += [p] * sample_lens - frequency_penalties += [f] * sample_lens - repetition_penalties += [r] * sample_lens - - if do_penalties: - for seq_group in sampling_metadata.seq_groups: - seq_ids = seq_group.seq_ids - sampling_params = seq_group.sampling_params - if (seq_group.is_prompt - and sampling_params.prompt_logprobs is not None): - prefill_len = len(seq_group.prompt_logprob_indices) - prompt_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - output_tokens.extend( - array(VLLM_TOKEN_ID_ARRAY_TYPE) - for _ in range(prefill_len)) - if seq_group.do_sample: - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids_array) - output_tokens.append(seq_data.output_token_ids_array) - - sampling_tensors = SamplingTensors.from_lists( - temperatures, - top_ps, - top_ks, - min_ps, - presence_penalties, - frequency_penalties, - repetition_penalties, - prompt_tokens, - output_tokens, - vocab_size, - device, - dtype, - ) - return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) - - @classmethod - def from_lists( - cls, - temperatures: list[float], - top_ps: list[float], - top_ks: list[int], - min_ps: list[float], - presence_penalties: list[float], - frequency_penalties: list[float], - repetition_penalties: list[float], - prompt_tokens: list[array], - output_tokens: list[array], - vocab_size: int, - device: torch.device, - dtype: torch.dtype, - ) -> "SamplingTensors": - # Note that the performance will be very bad without - # pinned memory. 
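# Editorial aside (not part of the patch): minimal sketch of the pinned-memory
# pattern used by `from_lists` below. Building the tensor on the CPU with
# pin_memory=True lets the subsequent .to(device, non_blocking=True) copy run
# asynchronously with respect to the host; without pinning, the copy is
# synchronous and much slower. The CUDA check is only so the sketch also runs
# on CPU-only machines.
import torch

values = [0.7, 1.0, 0.2, 0.9]
use_cuda = torch.cuda.is_available()
cpu_t = torch.tensor(values, device="cpu", dtype=torch.float32,
                     pin_memory=use_cuda)
device_t = cpu_t.to(device="cuda" if use_cuda else "cpu", non_blocking=True)
print(device_t.device, device_t.dtype)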
- pin_memory = is_pin_memory_available() - - do_penalties = prompt_tokens or output_tokens - - if do_penalties: - prompt_t = make_tensor_with_pad( - prompt_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - output_t = make_tensor_with_pad( - output_tokens, - vocab_size, - device="cpu", - dtype=torch.int64, - pin_memory=pin_memory, - ) - else: - empty_tensor = torch.empty(0, device=device, dtype=torch.long) - prompt_t = empty_tensor - output_t = empty_tensor - - temperatures_t = torch.tensor( - temperatures, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ps_t = torch.tensor( - top_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - min_ps_t = torch.tensor( - min_ps, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - presence_penalties_t = torch.tensor( - presence_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - frequency_penalties_t = torch.tensor( - frequency_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - repetition_penalties_t = torch.tensor( - repetition_penalties, - device="cpu", - dtype=dtype, - pin_memory=pin_memory, - ) - top_ks_t = torch.tensor( - top_ks, - device="cpu", - dtype=torch.int, - pin_memory=pin_memory, - ) - # Because the memory is pinned, we can do non-blocking - # transfer to device. - - return cls( - temperatures=temperatures_t.to(device=device, non_blocking=True), - top_ps=top_ps_t.to(device=device, non_blocking=True), - top_ks=top_ks_t.to(device=device, non_blocking=True), - min_ps=min_ps_t.to(device=device, non_blocking=True), - presence_penalties=presence_penalties_t.to(device=device, - non_blocking=True), - frequency_penalties=frequency_penalties_t.to(device=device, - non_blocking=True), - repetition_penalties=repetition_penalties_t.to(device=device, - non_blocking=True), - prompt_tokens=prompt_t.to(device=device, non_blocking=True), - output_tokens=output_t.to(device=device, non_blocking=True), - ) + # Placeholder until it can be safely removed. 
+ pass diff --git a/vllm/sequence.py b/vllm/sequence.py index 24114c0bb792..a6c194fbac0b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,28 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sequence and its related classes.""" -import copy -import enum -from abc import ABC, abstractmethod -from array import array -from collections import defaultdict -from collections.abc import Mapping -from collections.abc import Sequence as GenericSequence -from dataclasses import dataclass, field -from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional, Union import msgspec import torch -from vllm.inputs import SingletonInputs -from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs -from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import RequestOutputKind, SamplingParams - if TYPE_CHECKING: - from vllm.lora.request import LoRARequest from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorOutput) else: @@ -34,50 +19,6 @@ VLLM_INVALID_TOKEN_ID = -1 -def array_full(token_id: int, count: int): - """[`array`][] equivalent of [numpy.full][].""" - return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count - - -class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 - SWAPPED = 2 - # Note: anything after SWAPPED (2) will be considered - # as a finished status. - FINISHED_STOPPED = 3 - FINISHED_LENGTH_CAPPED = 4 - FINISHED_ABORTED = 5 - FINISHED_IGNORED = 6 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: - return status > SequenceStatus.SWAPPED - - @staticmethod - def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: - if status == SequenceStatus.FINISHED_STOPPED: - finish_reason = "stop" - elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: - finish_reason = "length" - elif status == SequenceStatus.FINISHED_ABORTED: - finish_reason = "abort" - elif status == SequenceStatus.FINISHED_IGNORED: - # The ignored sequences are the sequences whose prompt lengths - # are longer than the model's length cap. Therefore, the stop - # reason should also be "length" as in OpenAI API. - finish_reason = "length" - else: - finish_reason = None - return finish_reason - - -class SequenceStage(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - - @dataclass class RequestMetrics: """Metrics associated with a request. @@ -107,971 +48,12 @@ class RequestMetrics: model_execute_time: Optional[float] = None -class SequenceDataDelta( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta SequenceData to send to workers per step.""" - # A new token to be appended to existing SequenceData. - new_output_token_ids: list[int] - # Overwriting existing `cumulative_logprob` - new_cumulative_logprob: float - # Overwriting existing `num_computed_tokens`. - new_num_computed_tokens: int - # Overwriting existing `stage`. - new_stage: SequenceStage - - -class SequenceData(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Data associated with a sequence.""" - # NOTE: we cannot use Union[list, array] because msgspec cannot support - # union of 2 list types. 
- _prompt_token_ids: array - _output_token_ids: array = msgspec.field( - default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) - - _prompt_embeds: Optional[torch.Tensor] = None - _output_embeds: Optional[torch.Tensor] = None - - ### The below fields should not be passed as an argument ### - _cumulative_logprob: float = 0.0 - _prompt_token_ids_tuple: tuple[int, - ...] = msgspec.field(default_factory=tuple) - # The number of tokens that are computed (that run against the model). - _num_computed_tokens: int = 0 - # The number of tokens with prefix cache hit. - _num_cached_tokens: int = 0 - _stage: SequenceStage = SequenceStage.PREFILL - _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) - _cached_all_token_embeds: Optional[torch.Tensor] = None - - # It is used to get delta input. It is reset when `get_delta_and_reset` - # is called. - _new_appended_tokens: list[int] = msgspec.field(default_factory=list) - - # It is used to compute mrope_position_ids. - _mrope_position_delta: Optional[int] = None - - @staticmethod - def from_prompt_token_counts( - *token_counts: tuple[int, int]) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - by concatenating prompt token sequences. - - Each tuple represents one token sequence, expressed in the form - `(token_id, count)`. - """ - if len(token_counts) == 0: - return SequenceData.from_seqs([]) - - prompt_token_ids_arr = reduce( - array.__iadd__, - (array_full(token_id, count) for token_id, count in token_counts), - ) - - return SequenceData(prompt_token_ids_arr) - - @staticmethod - def from_seqs( - prompt_token_ids: GenericSequence[int], - output_token_ids: Optional[GenericSequence[int]] = None, - *, - prompt_embeds: Optional[torch.Tensor] = None, - ) -> "SequenceData": - """ - Construct a [`SequenceData`][vllm.sequence.SequenceData] instance - from prompt and output token sequences. - """ - prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - prompt_token_ids) - - if output_token_ids is None: - return SequenceData(prompt_token_ids_arr, - _prompt_embeds=prompt_embeds) - - output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, - output_token_ids) - - return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr, - _prompt_embeds=prompt_embeds) - - def __post_init__(self) -> None: - assert self._prompt_token_ids.typecode == "l" - assert self._output_token_ids.typecode == "l" - self._prompt_token_ids_tuple: tuple[int, ...] 
= tuple( - self._prompt_token_ids) - self._update_cached_all_tokens() - if self._prompt_embeds is not None: - self._update_cached_all_token_embeds() - - def _update_cached_all_tokens(self): - assert isinstance(self._prompt_token_ids, array) - assert isinstance(self._output_token_ids, array) - self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + - self._output_token_ids) - - def _update_cached_all_token_embeds(self): - assert isinstance(self._prompt_embeds, torch.Tensor) - self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds - if self._output_embeds is not None: - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, self._output_embeds), dim=0) - - @property - def cumulative_logprob(self) -> float: - """The cumulative log probability of the output.""" - return self._cumulative_logprob - - @property - def prompt_token_ids(self) -> tuple[int, ...]: - """The token IDs of the prompt.""" - return self._prompt_token_ids_tuple - - @prompt_token_ids.setter - def prompt_token_ids(self, new_prompt_token_ids) -> None: - raise NotImplementedError - - @property - def prompt_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. - """ - return self._prompt_token_ids - - @property - def output_token_ids(self) -> tuple[int, ...]: - """The token IDs of the output.""" - return tuple(self._output_token_ids) - - @output_token_ids.setter - def output_token_ids(self, - new_output_token_ids: GenericSequence[int]) -> None: - self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids) - self._update_cached_all_tokens() - - @property - def output_embeds(self) -> Optional[torch.Tensor]: - return self._output_embeds - - @output_embeds.setter - def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: - self._output_token_embeds = new_output_token_embeds - self._update_cached_all_token_embeds() - - @property - def output_token_ids_array(self) -> array: - """Return the prompt token ids in array type. - - Note that the array is in "I" type, and it is not compatible - with torch.long (2 bytes vs 4 bytes). So beware of the usage. 
- """ - assert isinstance(self._output_token_ids, array) - return self._output_token_ids - - @property - def prompt_embeds(self) -> Optional[torch.Tensor]: - return self._prompt_embeds - - @prompt_embeds.setter - def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: - self._prompt_embeds = prompt_embeds - self._update_cached_all_token_embeds() - - @property - def mrope_position_delta(self) -> Optional[int]: - return self._mrope_position_delta - - @mrope_position_delta.setter - def mrope_position_delta(self, new_mrope_position_delta): - self._mrope_position_delta = new_mrope_position_delta - - def append_token_id(self, - token_id: int, - logprob: float, - token_embed: Optional[torch.Tensor] = None) -> None: - self._output_token_ids.append(token_id) - self._new_appended_tokens.append(token_id) - self._cached_all_token_ids.append(token_id) - self._cumulative_logprob += logprob - if token_embed is not None: - # Do not pass in with batch or sequence dimensions - assert token_embed.ndim == 1 - token_embed = token_embed.detach().cpu().unsqueeze(0) - if self._output_embeds is None: - self._output_embeds = token_embed - else: - self._output_embeds = torch.cat( - (self._output_embeds, token_embed), dim=0) - assert self._cached_all_token_embeds is not None - self._cached_all_token_embeds = torch.cat( - (self._cached_all_token_embeds, - token_embed.to(device=self._cached_all_token_embeds.device)), - dim=0) - - def get_len(self) -> int: - return len(self._output_token_ids) + len(self._prompt_token_ids) - - def get_prompt_len(self) -> int: - return len(self._prompt_token_ids) - - def get_output_len(self) -> int: - return len(self._output_token_ids) - - def get_token_ids(self) -> list[int]: - return self._cached_all_token_ids - - def get_token_embeddings(self) -> Optional[torch.Tensor]: - return self._cached_all_token_embeds - - def get_prefix_token_ids( - self, num_tokens: int - ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: - """Get prefix tokens, and make the return value hashable""" - prompt_length = self.get_prompt_len() - if num_tokens > prompt_length: - return (self._prompt_token_ids_tuple, - tuple(self._output_token_ids[:num_tokens - prompt_length])) - else: - return (self._prompt_token_ids_tuple[:num_tokens], None) - - def get_num_computed_tokens(self) -> int: - """Return the number of prefill tokens that are already computed.""" - return self._num_computed_tokens - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - self._num_computed_tokens += num_new_computed_tokens - assert self._num_computed_tokens <= self.get_len(), ( - self._num_computed_tokens, self.get_len()) - # If all tokens are computed, it means it is in decoding phase. - if self.get_num_uncomputed_tokens() == 0: - self._stage = SequenceStage.DECODE - - def get_num_cached_tokens(self) -> int: - """Return the number of tokens with prefix cache hit.""" - return self._num_cached_tokens - - def update_num_cached_tokens(self, num_cached_tokens: int): - """Update the number of tokens with prefix cache hit.""" - self._num_cached_tokens = num_cached_tokens - - def reset_state_for_recompute(self) -> None: - """Reset the number of computed tokens from this sequence. It is - supposed to be called when a sequence needs to be started from - the beginning again (e.g., sequence is preempted). 
- """ - self._num_computed_tokens = 0 - self._stage = SequenceStage.PREFILL - self._new_appended_tokens = [] - - def get_num_uncomputed_tokens(self) -> int: - """Return the number of prefill tokens that are not computed.""" - # we use `get_len()` which includes prompt_len + output_len instead - # of prompt_len here. This is because during recompute we need to - # prefill for both prompt and output. - return self.get_len() - self.get_num_computed_tokens() - - def get_last_token_id(self) -> int: - if not self._output_token_ids: - return self._prompt_token_ids[-1] - return self._output_token_ids[-1] - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.prompt_token_ids - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.output_token_ids - - def get_delta_and_reset(self) -> SequenceDataDelta: - delta = SequenceDataDelta(self._new_appended_tokens, - self._cumulative_logprob, - self.get_num_computed_tokens(), self.stage) - # Reset delta state. - self._new_appended_tokens = [] - return delta - - def apply_delta(self, delta: SequenceDataDelta): - self._num_computed_tokens = delta.new_num_computed_tokens - self._cumulative_logprob = delta.new_cumulative_logprob - self._stage = delta.new_stage - self._output_token_ids.extend(delta.new_output_token_ids) - self._cached_all_token_ids.extend(delta.new_output_token_ids) - - @property - def stage(self) -> SequenceStage: - return self._stage - - def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self._prompt_token_ids}, " - f"prompt_embeds.shape=" - f"{getattr(self._prompt_embeds, 'shape', None)}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"get_num_computed_tokens={self.get_num_computed_tokens()})") - - -class Sequence: - """Stores the data, status, and block information of a sequence. - - The sequence is constructed from the - [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] (for decoder-only) - or [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - (for encoder-decoder) instance passed in through the `inputs` - constructor argument. - - Args: - seq_id: The ID of the sequence. - inputs: The inputs of the sequence. - block_size: The block size of the sequence. Should be the same as the - block size used by the block manager and cache engine. - eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. - lora_request: LoRA request. 
- """ - - def __init__( - self, - seq_id: int, - inputs: SingletonInputs, - block_size: int, - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = inputs - self.block_size = block_size - self.eos_token_id = eos_token_id - self.lora_request = lora_request - - self.data = SequenceData.from_seqs( - self.prompt_token_ids, - prompt_embeds=self.inputs["prompt_embeds"] - if self.inputs["type"] == "embeds" else None) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" - - self.status = SequenceStatus.WAITING - self.stop_reason: Union[int, str, None] = None - - # These are used to keep track of delta outputs - self._last_output_token_ids_offset: int = 0 - self._last_output_text_offset: int = 0 - - # Used for incremental detokenization - self.prefix_offset = 0 - self.read_offset = 0 - # Input + output tokens - self.tokens: Optional[list[str]] = None - - @property - def n_blocks(self) -> int: - return (self.get_len() + self.block_size - 1) // self.block_size - - @property - def prompt(self) -> Optional[str]: - if self.inputs["type"] == "embeds": - return None - return self.inputs.get("prompt") - - @property - def prompt_token_ids(self) -> list[int]: - if self.inputs["type"] == "embeds": - return [0] * len(self.inputs["prompt_embeds"]) - return self.inputs["prompt_token_ids"] - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_kwargs"].get_data() - - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.inputs["type"] == "multimodal": - return self.inputs["mm_placeholders"] - - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def get_output_text_to_return(self, buffer_length: int, - delta: bool) -> str: - """If delta is True, only new text since the last call to - this method is returned""" - - # We return the full output text if the sequence is finished. - truncate = buffer_length and not self.is_finished() - if not delta: - return self.output_text[:-buffer_length] if truncate else ( - self.output_text) - length = len(self.output_text) - if truncate: - length -= buffer_length - last_offset = self._last_output_text_offset - if last_offset < length: - self._last_output_text_offset = length - return self.output_text[last_offset:length] - return "" - - def get_output_token_ids_to_return( - self, delta: bool) -> Union[GenericSequence[int], int]: - """If delta is True, only new tokens since the last call to - this method are returned""" - if not delta: - return self.get_output_token_ids() - - output_len = self.get_output_len() - - # Get the number of new tokens - num_new_tokens = output_len - self._last_output_token_ids_offset - self._last_output_token_ids_offset = output_len - - # Return new tokens - if num_new_tokens == 1: - # Optimization for single decode token case - # (which is what we have most of the time) - return self.data._cached_all_token_ids[-1] - - if num_new_tokens == 0: - return [] - - return self.data._cached_all_token_ids[-num_new_tokens:] - - def hash_of_block(self, logical_idx: int) -> int: - # TODO This can produce incorrect hash when block size > prompt size - - # Compute the number of tokens in the sequence - # TODO: The current hashing function is O(L^2). We should optimize - # this in the future. 
- num_tokens = self.num_hashed_tokens_of_block(logical_idx) - hashed_tokens = self.data.get_prefix_token_ids(num_tokens) - return hash((hashed_tokens, self.lora_int_id)) - - def extra_hash(self) -> Optional[int]: - """ - This function computes an extra hash for a sequence, specifically - designed for prefix caching mode. The final sequence hash is determined - by applying token_ids from the sequence's blocks. - """ - if self.lora_int_id == 0: - return None - - # NOTE: If there are additional factors influencing the block aside from - # token_ids, include them as input parameters to the hash. - return hash(self.lora_int_id) - - def num_hashed_tokens_of_block(self, logical_idx: int): - return logical_idx * self.block_size + self.block_size - - def reset_state_for_recompute(self): - """Reset the sequence states for recomputation.""" - self.data.reset_state_for_recompute() - - def append_token_id(self, - token_id: int, - logprobs: dict[int, Logprob], - token_embed: Optional[torch.Tensor] = None) -> None: - assert token_id in logprobs - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob, - token_embed) - - def get_len(self) -> int: - return self.data.get_len() - - def get_prompt_len(self) -> int: - return self.data.get_prompt_len() - - def get_output_len(self) -> int: - return self.data.get_output_len() - - def get_token_ids(self) -> list[int]: - return self.data.get_token_ids() - - def get_prompt_token_ids(self) -> tuple[int, ...]: - return self.data.get_prompt_token_ids() - - def get_last_token_id(self) -> int: - return self.data.get_last_token_id() - - def get_output_token_ids(self) -> tuple[int, ...]: - return self.data.get_output_token_ids() - - def get_cumulative_logprob(self) -> float: - return self.data.cumulative_logprob - - def is_finished(self) -> bool: - return SequenceStatus.is_finished(self.status) - - def fork(self, new_seq_id: int) -> "Sequence": - new_seq = copy.deepcopy(self) - new_seq.seq_id = new_seq_id - return new_seq - - def get_num_new_tokens(self) -> int: - """Get the number of new tokens to be computed. - - Returns: - The new number of tokens to be computed. I.e., 1 for decode, or - the remaining prompt size for prefill. - """ - if self.data.stage == SequenceStage.DECODE: - return 1 - return self.data.get_num_uncomputed_tokens() - - def get_num_computed_tokens(self) -> int: - return self.data.get_num_computed_tokens() - - def is_prefill(self) -> bool: - return self.data.stage == SequenceStage.PREFILL - - def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={self.n_blocks})") - - -class SequenceGroupState(msgspec.Struct, - omit_defaults=True): # type: ignore[call-arg] - """Mutable state tied to a specific sequence group""" - - # for multi-step decoding - num_steps: int = 1 - current_step: int = 0 - - @property - def remaining_steps(self) -> int: - return self.num_steps - self.current_step - - -class SequenceGroup: - """A group of sequences that are generated from the same prompt. - - Args: - request_id: The ID of the request. - seqs: The list of sequences. - sampling_params: The sampling parameters used to generate the outputs. - arrival_time: The arrival time of the request. - lora_request: LoRA request. - pooling_params: The parameters used to generate the pooler - for a pooling model. - pooled_data: The extracted hidden states from a pooling model. - encoder_seq: Optional, the single encoder sequence. 
Should be None - unless you are working with an encoder/decoder model. - trace_headers: OpenTelemetry trace headers. - priority: User-defined priority of the request. - draft_size: The number of speculative tokens plus one from the target - model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than - that for multi-draft SD (currently not supported). - """ - - def __init__(self, - request_id: str, - seqs: list[Sequence], - arrival_time: float, - sampling_params: Optional[SamplingParams] = None, - lora_request: Optional[LoRARequest] = None, - pooling_params: Optional[PoolingParams] = None, - pooled_data: Optional[torch.Tensor] = None, - encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, - draft_size: int = 1) -> None: - self.request_id = request_id - self.seqs = seqs - self.first_seq = seqs[0] - self.arrival_time = arrival_time - self.is_single_seq = len(seqs) == 1 - self.seqs_dict = {seq.seq_id: seq for seq in seqs} - - self.sampling_params = sampling_params - self.metrics = RequestMetrics(arrival_time=arrival_time, - last_token_time=arrival_time, - first_scheduled_time=None, - first_token_time=None, - time_in_queue=None) - self.last_token_latency = 0.0 - self.lora_request = lora_request - self.prompt_logprobs: Optional[PromptLogprobs] = None - self.state = SequenceGroupState() - self.pooling_params = pooling_params - self.pooled_data = pooled_data - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority - - self.cached_request_output = None - - @property - def prompt(self) -> Optional[str]: - return self.first_seq.prompt - - @property - def prompt_token_ids(self) -> list[int]: - return self.first_seq.prompt_token_ids - - @property - def encoder_prompt(self) -> Optional[str]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt is distinct - # from the decoder's. - return (self.encoder_seq.prompt - if self.encoder_seq is not None else None) - - @property - def encoder_prompt_token_ids(self) -> Optional[list[int]]: - # There are either 0 or 1 encoder sequences - # If one is present, its prompt token ids are - # distinct from the decoder's. - return (self.encoder_seq.prompt_token_ids - if self.encoder_seq is not None else None) - - @property - def multi_modal_data(self) -> MultiModalKwargs: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_data - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_data - return MultiModalKwargs() - - @property - def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: - if self.first_seq.multi_modal_data: - return self.first_seq.multi_modal_placeholders - elif self.encoder_seq is not None: - return self.encoder_seq.multi_modal_placeholders - return {} - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - def set_last_token_time(self, now: float) -> None: - """Sets the last token time for Request level timings.""" - # If still in prefill phase, assertion fails. 
- assert not self.is_prefill(), ( - "seq_group.set_last_token_time() should not be called " - "if the seq_group is in prefill phase.") - self.last_token_latency = now - self.metrics.last_token_time - self.metrics.last_token_time = now - - def get_last_token_latency(self) -> float: - """Returns the latency of the last token.""" - assert not self.is_prefill(), ( - "seq_group.get_last_token_latency() should not be called " - "if the seq_group is in prefill phase.") - return self.last_token_latency - - def maybe_set_first_token_time(self, time: float) -> None: - """Sets the first token time for Request level timings.""" - # Note: in a case where a sequence_group is swapped and - # recomputed, the time between iterations is counted - # in TPOT, rather than recalculating TTFT (since from the ) - # POV of the user, there is simply a long generation delay. - if (self.metrics.first_token_time is None - and self.first_seq.get_output_len() == 1): - self.metrics.first_token_time = time - - def maybe_set_first_scheduled_time(self, time: float) -> None: - """Sets the first scheduled time and time in queue for Request - level timings.""" - if self.metrics.first_scheduled_time is None: - self.metrics.first_scheduled_time = time - self.metrics.time_in_queue = time - self.metrics.arrival_time - - def set_finished_time(self, time: Optional[float]) -> None: - """Sets the finished time for Request level timings.""" - self.metrics.finished_time = time - - def get_max_num_running_seqs(self) -> int: - """The maximum number of sequences running in parallel in the remaining - lifetime of the request.""" - if self.is_single_seq: - return 0 if self.first_seq.is_finished() else 1 - return self.num_seqs() - self.num_finished_seqs() - - def get_seqs( - self, - status: Optional[SequenceStatus] = None, - ) -> list[Sequence]: - if status is None: - return self.seqs - - if self.is_single_seq: - return self.seqs if self.first_seq.status == status else [] - - return [seq for seq in self.seqs if seq.status == status] - - def is_encoder_decoder(self) -> bool: - return self.encoder_seq is not None - - def get_encoder_seq(self) -> Optional[Sequence]: - return self.encoder_seq - - def get_finished_seqs(self) -> list[Sequence]: - if self.is_single_seq: - return self.seqs if self.first_seq.is_finished() else [] - - return [seq for seq in self.seqs if seq.is_finished()] - - def update_num_computed_tokens(self, num_new_computed_tokens: int): - """Update number of tokens computed so far.""" - for seq in self.seqs: - if not seq.is_finished(): - seq.data.update_num_computed_tokens(num_new_computed_tokens) - - def get_num_uncomputed_tokens(self) -> int: - num_uncomputed_tokens = 0 - for seq in self.seqs: - if not seq.is_finished(): - num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() - return num_uncomputed_tokens - - def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: - # Optimization. We don't need to call get_seqs if we don't need to - # filter by states. 
- if status is None: - return len(self.seqs) - - if self.is_single_seq: - return 1 if self.seqs[0].status == status else 0 - - return len(self.get_seqs(status)) - - def num_finished_seqs(self) -> int: - if self.is_single_seq: - return 1 if self.seqs[0].is_finished() else 0 - return len(self.get_finished_seqs()) - - def is_finished(self) -> bool: - if self.is_single_seq: - return self.first_seq.is_finished() - return all(seq.is_finished() for seq in self.seqs) - - def is_prefill(self) -> bool: - return self.first_seq.is_prefill() - - def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs)})") - - def uses_prompt_embeds(self) -> bool: - """Returns True if the sequence group uses input embeds.""" - return any(seq.data.prompt_embeds is not None for seq in self.seqs) - - -class SequenceGroupMetadataDelta( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Delta of SequenceGroupMetadata. - - After sending the first SequenceGroupMetadata, vLLM scheduler - only sends delta to reduce the data payload size. - """ - seq_data_delta: dict[int, SequenceDataDelta] - request_id: str - block_tables: dict[int, list[int]] - is_prompt: bool - do_sample: bool = True - token_chunk_size: Optional[int] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - - -class SequenceGroupMetadata( - msgspec.Struct, - tag=True, # type: ignore[call-arg] - array_like=True, # type: ignore[call-arg] - omit_defaults=True): # type: ignore[call-arg] - """Metadata for a sequence group. Used to create `AttentionMetadata`. - - Attributes: - request_id: The ID of the request. - is_prompt: Whether the request is at prompt stage. - seq_data: The sequence data. (Seq id -> sequence data) - sampling_params: The sampling parameters used to generate the outputs. - block_tables: The block tables. (Seq id -> list of physical block - numbers) - do_sample: True if sampling is required. Sampling is not required when - e.g., prefill is chunked, and the current iteration only computes - query tokens for prefill, we don't need sampling. - pooling_params: Pooling parameters. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. - state: Internal state tied to this sequence group. - token_type_ids: Token type IDs. - multi_modal_data: Multi modal data. - multi_modal_placeholders: Multi modal placeholders. - encoder_seq_data: Optional sequence data for encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. - cross_block_table: Optional cross-attention block table associated - with the encoder prompt - (SequenceGroup.encoder_seq). Should be None - unless you are working with an encoder/decoder - model. 
- """ - - request_id: str - is_prompt: bool - seq_data: dict[int, SequenceData] - sampling_params: Optional[SamplingParams] - block_tables: dict[int, list[int]] - do_sample: bool = True - pooling_params: Optional[PoolingParams] = None - lora_request: Optional[LoRARequest] = None - computed_block_nums: Optional[list[int]] = None - state: Optional[SequenceGroupState] = msgspec.field( - default_factory=lambda: SequenceGroupState()) - multi_modal_data: Optional[MultiModalKwargs] = None - multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None - encoder_seq_data: Optional[SequenceData] = None - cross_block_table: Optional[list[int]] = None - token_chunk_size: Optional[int] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. - # None means specuative decoding is not used. - # Zero means speculative decoding is disabled for some reasons. - # TODO: We should maintain this states out of the sequence group. - num_speculative_tokens: Optional[int] = None - - def __post_init__(self): - if self.seq_data is not None and self.token_chunk_size is None: - if self.is_prompt: - self.token_chunk_size = next(iter( - self.seq_data.values())).get_len() - else: - self.token_chunk_size = 1 - - @property - def lora_int_id(self) -> int: - return self.lora_request.lora_int_id if self.lora_request else 0 - - # Multi-Step Chunked-Prefill property - @property - def is_single_step_prompt(self) -> bool: - # do_sample is true, only when the token_chunk_size matches the - # num_uncomputed_tokens of the sequence. This indicates that - # the prompt will finish processing in a single `execute_model` - # step. - return self.is_prompt and self.do_sample - - def get_first_seq_id(self) -> int: - # This is an efficient way of fetching the seq_id when - # we know this SequenceGroup has only one sequence. - return next(iter(self.seq_data)) - - def apply_delta(self, - sequence_group_metadata_delta: SequenceGroupMetadataDelta): - for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): - self.seq_data[id].apply_delta(delta) - assert self.request_id == sequence_group_metadata_delta.request_id - self.block_tables = sequence_group_metadata_delta.block_tables - self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size - self.do_sample = sequence_group_metadata_delta.do_sample - self.is_prompt = sequence_group_metadata_delta.is_prompt - - def finish_step(self) -> None: - assert self.state is not None - assert self.state.current_step < self.state.num_steps, \ - f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa - self.state.current_step += 1 - - -class SequenceOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a sequence. - - Attributes: - parent_seq_id: The ID of the parent sequence (for forking in beam - search). - output_token: The output token ID. - logprobs: The logprobs of the output token. - (Token id -> logP(x_i+1 | x_0, ..., x_i)) - output_embed: Optional output embedding tensor. 
- """ - parent_seq_id: int - output_token: int - logprobs: dict[int, Logprob] - output_embed: Optional[torch.Tensor] = None - - def __repr__(self) -> str: - output_embed_shape = \ - self.output_embed.shape if self.output_embed is not None else None - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"output_embed.shape={output_embed_shape}, " - f"logprobs={self.logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): - raise NotImplementedError() - equal = (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token) - log_probs_equal = other.logprobs == self.logprobs - return equal and log_probs_equal - - -class SequenceGroupOutput(ABC): - """The base class for model outputs associated with a sequence group.""" - - @abstractmethod - def __repr__(self) -> str: - pass - - @abstractmethod - def __eq__(self, other: object) -> bool: - pass - - -class CompletionSequenceGroupOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """The model output associated with a completion sequence group.""" - __metaclass__ = SequenceGroupOutput - samples: list[SequenceOutput] - # Prompt logprob for each prompt query token. - prompt_logprobs: Optional[PromptLogprobs] - step_index: Optional[int] = 0 - - def __repr__(self) -> str: - return (f"CompletionSequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") - - def __eq__(self, other: object) -> bool: - if not isinstance(other, CompletionSequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) - - class PoolingSequenceGroupOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg] ): """The model output associated with a pooling sequence group.""" - __metaclass__ = SequenceGroupOutput # Annotated as Any to be compatible with msgspec # The actual type is in SequenceGroup.pooled_data data: Any @@ -1161,305 +143,9 @@ def __eq__(self, other: object): self.__class__) and self.outputs == other.outputs -def get_all_seq_ids( - seq_group_metadata_list: list[SequenceGroupMetadata]) -> list[int]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] - - -def get_all_seq_ids_and_request_ids( - seq_group_metadata_list: list[SequenceGroupMetadata] -) -> tuple[list[int], dict[str, set[int]]]: - """Given a list of SequenceGroupMetadata, create a list of all - sequence ids. - """ - seq_ids: list[int] = [] - request_id_seq_ids_mapping: defaultdict[str, set[int]] = defaultdict(set) - for sg in seq_group_metadata_list: - for seq_id in sg.seq_data: - seq_ids.append(seq_id) - request_id_seq_ids_mapping[sg.request_id].add(seq_id) - return seq_ids, request_id_seq_ids_mapping - - -class HiddenStates(msgspec.Struct, array_like=True, - omit_defaults=True): # type: ignore[call-arg] - """Hidden states corresponding to in-progress sequences. - Used in speculative decoding to pass hidden states from - the target model to the proposer model. - - seq_ids are the sequence ids of each entry of the batch - dimension of the hidden_states tensor""" - # Scorer hidden states. For prefill step, it is used for hidden states of - # all tokens, whereas for decode step, it is used for last accepted tokens. 
- hidden_states: torch.Tensor - # The sequence group metadata list. Only needed for decode step. - seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None - # Scorer hidden states of the 2nd last token proposed by the proposer ( - # irrespective of whether it was accepted or not). Only used for cases when - # last proposed token is accepted (i.e., in case of bonus tokens). For the - # case of no bonus tokens, these are ignored. - second_last_token_hidden_states: Optional[torch.Tensor] = None - - _seq_ids: list[int] = msgspec.field(default_factory=list) - - def __post_init__(self): - if self.seq_group_metadata_list is not None: - assert len(self.seq_group_metadata_list) == len(self.hidden_states) - self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) - - @property - def seq_ids(self) -> list[int]: - return self._seq_ids - - def update(self, - hidden_states: torch.Tensor, - seq_group_metadata_list: list[SequenceGroupMetadata], - second_last_token_hidden_states: Optional[torch.Tensor] = None): - """Update hidden states from target model invocation. Only used for - decode steps""" - assert len(seq_group_metadata_list) == len(hidden_states) - self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) - self.hidden_states = torch.cat([self.hidden_states, hidden_states]) - - if self.second_last_token_hidden_states is not None: - # Adding dummy hidden_states to this to maintain same shape - self.second_last_token_hidden_states = torch.cat([ - self.second_last_token_hidden_states, - torch.zeros_like(hidden_states) - if second_last_token_hidden_states is None else - second_last_token_hidden_states - ]) - - def prune(self, - seq_group_metadata_list: list[SequenceGroupMetadata]) -> None: - """Prune to provided list of sequence ids. Only used for decode steps. - """ - # Currently this prunes all seq_ids not present in - # seq_group_metadata_list which might cause problems where a sequence - # may be "paused" then "resumed" later. This should only prune sequences - # which are confirmed to be aborted. - seq_ids = get_all_seq_ids(seq_group_metadata_list) - # Only keep sequence IDs that exist in self._seq_ids - seq_ids = [seq_id for seq_id in seq_ids if seq_id in self._seq_ids] - if seq_ids != self._seq_ids: - # Batch contents changed - prune removed sequences. - index = [self._seq_ids.index(seq_id) for seq_id in seq_ids] - self.hidden_states = self.hidden_states[index] - if self.second_last_token_hidden_states is not None: - self.second_last_token_hidden_states = self\ - .second_last_token_hidden_states[index] - self._seq_ids = seq_ids - - def expand_with_bonus_tokens( - self, seq_with_bonus_token_in_last_step: set) -> None: - """Expand hidden states for sequences with bonus tokens. This is in - alignment with `MultiStepWorker._expand_execute_model_request`.""" - if self.second_last_token_hidden_states is None \ - or not seq_with_bonus_token_in_last_step: - return - - index = [] - for seq_id in self._seq_ids: - i = self._seq_ids.index(seq_id) - if seq_id in seq_with_bonus_token_in_last_step: - index.append(i + len(self._seq_ids)) - index.append(i) - - self.hidden_states = torch.cat( - [self.hidden_states, self.second_last_token_hidden_states])[index] - - class ExecuteModelRequest( msgspec.Struct, array_like=True, # type: ignore[call-arg] omit_defaults=True): # type: ignore[call-arg] - """The model execution request, containing CPU metadata only. The LLM - engine should create an instance of this class for each request batch.""" - # The sequence group metadata list. 
- seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - # Blocks to swap in. List of CPU -> GPU block number. - blocks_to_swap_in: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to swap out. List of GPU -> CPU block number. - blocks_to_swap_out: list[tuple[int, - int]] = msgspec.field(default_factory=list) - # Blocks to copy. Source to dest block. - blocks_to_copy: list[tuple[int, int]] = msgspec.field(default_factory=list) - # Virtual engine ID for pipeline parallel. - virtual_engine: int = 0 - # The number of slots for lookahead decoding. - num_lookahead_slots: int = 0 - # The number of requests in the running queue. - running_queue_size: int = 0 - # Optional hidden states from prior step. - previous_hidden_states: Optional[HiddenStates] = None - # The number of forward steps to run. - num_steps: int = 1 - # Finished request ids since last step. - finished_requests_ids: list[str] = msgspec.field(default_factory=list) - # The last sampled token ids for multi step decoding. - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None - - @property - def is_last_step(self) -> bool: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - first_seq_group = self.seq_group_metadata_list[0] - assert first_seq_group.state is not None - return first_seq_group.state.remaining_steps == 1 - - @property - def current_step(self) -> int: - # TODO(will) make this be able to handle batches with variable number of - # steps - assert len(self.seq_group_metadata_list) > 0 - state = self.seq_group_metadata_list[0].state - assert state is not None - return state.current_step - - def clone( - self, seq_group_metadata_list: list[Union[SequenceGroupMetadata, - SequenceGroupMetadataDelta]] - ) -> "ExecuteModelRequest": - """Clone the request with a new sequence group metadata list.""" - return ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=self.blocks_to_swap_in.copy(), - blocks_to_swap_out=self.blocks_to_swap_out.copy(), - blocks_to_copy=self.blocks_to_copy.copy(), - virtual_engine=self.virtual_engine, - num_lookahead_slots=self.num_lookahead_slots, - running_queue_size=self.running_queue_size, - previous_hidden_states=self.previous_hidden_states, - num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids, - last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback) - - -@dataclass -class SequenceGroupBase: - group_id: str # the original request id before splitting - - assembled_seq_group: Optional[SequenceGroup] = None - - # seq id to a unique index inside this group - seq_id_to_index: dict[str, int] = field(default_factory=dict) - - # seq ids to be finished - to_be_finished: dict[str, SequenceGroup] = field(default_factory=dict) - - # seq id to finished sequences - finished_reqs: dict[str, SequenceGroup] = field(default_factory=dict) - - streaming: bool = False - - output_produced: bool = False - - @staticmethod - def add_request(request_id: str, engine, params, *args, **kwargs): - """When we are ready to add a request with request_id and params - into the engine, we can split the request into multiple requests. - """ - raise NotImplementedError - - def finish_seq(self, seq: SequenceGroup): - """The sequence `seq` finishes, we should record the information. 
- """ - del self.to_be_finished[seq.request_id] - self.finished_reqs[seq.request_id] = seq - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - """Assemble the sequence group, for producing the final - output, or adding request in the engine again. - """ - raise NotImplementedError - - -class ParallelSampleSequenceGroup(SequenceGroupBase): - - @staticmethod - def add_request(request_id: str, engine, params, **kwargs): - original_params = params - group = ParallelSampleSequenceGroup(request_id) - seqs = [] - for i in range(original_params.n): - request_id_i = f"{request_id}_parallel_sample_{i}" - group.seq_id_to_index[request_id_i] = i - params = original_params.clone() - params.n = 1 - if params.seed is not None: - params.seed += i - seq_group = engine._add_processed_request( - request_id_i, - params=params, - **kwargs, - ) # type: ignore - assert seq_group is not None - engine.seq_id_to_seq_group[request_id_i] = group - group.to_be_finished[request_id_i] = seq_group - seqs.append(seq_group.seqs[0]) - - # for parallel sampling, the `assembled_seq_group` is always - # available, since we have all the sequences ready, and they - # will not change. - group.assembled_seq_group = SequenceGroup( - request_id=request_id, - seqs=seqs, - arrival_time=seq_group.arrival_time, - sampling_params=original_params, - lora_request=seq_group.lora_request, - pooling_params=seq_group.pooling_params, - pooled_data=seq_group.pooled_data, - encoder_seq=seq_group.encoder_seq, - trace_headers=seq_group.trace_headers, - priority=seq_group.priority, - ) - - group.streaming = params.output_kind == RequestOutputKind.DELTA - group.output_produced = False - - def maybe_assemble_group( - self, seq_group: SequenceGroup) -> Optional[SequenceGroup]: - - # in the streaming mode, we will return the assembled sequence - # for the first remaining sequence, and then return None for the - # rest of sequences - if self.streaming: - first_remaining_id = next(iter(self.to_be_finished)) - if seq_group.request_id == first_remaining_id: - return self.assembled_seq_group - return None - - # in the non-streaming mode, we will return the assembled sequence - # when the last sequences finishes, and then return None for the - # rest of the time - if (len(self.to_be_finished) == 1 - and seq_group.request_id in self.to_be_finished - and seq_group.is_finished()): - assert self.assembled_seq_group is not None - params = self.assembled_seq_group.sampling_params - assert isinstance(params, SamplingParams) - if not self.output_produced: - self.output_produced = True - if params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group - if self.output_produced: - return None - return None + # Placeholder. Remove. 
+ pass diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py deleted file mode 100644 index e2d2846a2807..000000000000 --- a/vllm/transformers_utils/detokenizer.py +++ /dev/null @@ -1,162 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -from vllm.logprobs import Logprob -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, SamplingParams, Sequence, - SequenceGroup) - -from .detokenizer_utils import (convert_prompt_ids_to_tokens, - detokenize_incrementally) -from .tokenizer import AnyTokenizer - - -class Detokenizer: - """Provides methods to decode the output of a model into text.""" - - def __init__(self, tokenizer: AnyTokenizer): - self.tokenizer = tokenizer - - def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, - prompt_logprobs: list[Optional[dict[ - int, Logprob]]], - position_offset: int) -> None: - """Decodes the logprobs for the prompt of a sequence group. - - Args: - seq_group: The sequence group to decode. - prompt_logprobs: The logprobs to decode. - position_offset: Offset of the first index of the logprobs - relative to the start of the sequence (for chunked prefill). - - Returns: - The prompt logprobs with the decoded tokens. - """ - prms = seq_group.sampling_params - assert prms is not None - - # We can pick any sequence for the prompt. - seq = seq_group.get_seqs()[0] - # Only prompt, without the generated token. - all_token_ids = seq.get_token_ids() - prompt_token_ids = all_token_ids[:-1] - prefix_offset = 0 - read_offset = 0 - next_iter_prefix_offset = 0 - next_iter_read_offset = 0 - next_iter_tokens: list[str] = [] - prev_tokens = None - - for token_position_in_logprob, prompt_logprobs_for_token in enumerate( - prompt_logprobs): - - # Absolute token position equals the index in the logprobs - # list plus the offset of the entire logprobs list relative - # to the start of the sequence. - token_position = token_position_in_logprob + position_offset - if not prompt_logprobs_for_token: - continue - for token_id, sample_logprob in prompt_logprobs_for_token.items(): - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - prompt_token_ids_with_token = ( - prompt_token_ids[:token_position] + [token_id]) - (new_tokens, new_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=prompt_token_ids_with_token, - prev_tokens=prev_tokens, - prefix_offset=prefix_offset, - read_offset=read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - - sample_logprob.decoded_token = new_text - - # Use the offsets & prev tokens corresponding to - # real tokens to ensure detokenization is consistent - # actual with prompt. - if token_id == all_token_ids[token_position]: - next_iter_prefix_offset = new_prefix_offset - next_iter_read_offset = new_read_offset - next_iter_tokens = new_tokens - - # Advance to the next token position. - prefix_offset = next_iter_prefix_offset - read_offset = next_iter_read_offset - if prev_tokens is None: - prev_tokens = next_iter_tokens.copy() - else: - prev_tokens.extend(next_iter_tokens) - - def decode_sequence_inplace(self, seq: Sequence, - prms: SamplingParams) -> int: - """Decodes the new token for a sequence. In-place operation. - - Args: - seq: The sequence to decode. - prms: The sampling parameters used to generate the sequence. 
- - Returns: - The number of characters added to the output text. - """ - all_input_ids = seq.get_token_ids() - token_id_generated_this_iteration = all_input_ids[-1] - - # Convert prompt token IDs to tokens if necessary. - # Do it here so that we don't have to repeat this - # computation for each logprob. - if seq.tokens is None: - (seq.tokens, seq.prefix_offset, - seq.read_offset) = convert_prompt_ids_to_tokens( - tokenizer=self.tokenizer, - prompt_ids=all_input_ids[:-1], - skip_special_tokens=prms.skip_special_tokens, - ) - - (new_tokens, new_decoded_token_text, prefix_offset, - read_offset) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - - # Decode logprobs - logprobs = seq.output_logprobs[-1] - if logprobs: - previous_tokens = all_input_ids[:-1] - for token_id, sample_logprob in logprobs.items(): - # If the token was generated this iteration, - # use the provided text. - if token_id == token_id_generated_this_iteration: - sample_logprob.decoded_token = new_decoded_token_text - continue - - if (sample_logprob.decoded_token is None - and token_id != VLLM_INVALID_TOKEN_ID): - all_input_ids_with_logprob = previous_tokens + [token_id] - (_, new_text, _, _) = detokenize_incrementally( - tokenizer=self.tokenizer, - all_input_ids=all_input_ids_with_logprob, - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms. - spaces_between_special_tokens, - ) - sample_logprob.decoded_token = new_text - - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_decoded_token_text - - return len(new_decoded_token_text) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index eaab976bf7f7..20fabef4f19b 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,12 +11,12 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, update_environment_variables, warn_for_unimplemented_methods) +from vllm.v1.outputs import SamplerOutput logger = init_logger(__name__) From 66b1e08ee6034bad3933c6014ec29d96f6978aac Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Sep 2025 08:52:32 -0700 Subject: [PATCH 06/17] [V0 Deprecation] Remove async_output_proc, preemption mode, delay factor (#25334) Signed-off-by: Woosuk Kwon --- tests/detokenizer/test_stop_strings.py | 46 ++---------------- .../test_processor_multi_modal_uuids.py | 10 ---- tests/v1/test_oracle.py | 18 ------- vllm/config/__init__.py | 4 -- vllm/config/model.py | 48 +++---------------- vllm/config/scheduler.py | 15 +----- vllm/engine/arg_utils.py | 34 ------------- vllm/entrypoints/llm.py | 4 -- vllm/executor/uniproc_executor.py | 4 -- vllm/platforms/cpu.py | 4 -- vllm/platforms/cuda.py | 10 ---- vllm/platforms/interface.py | 7 --- vllm/platforms/rocm.py | 10 ---- vllm/platforms/tpu.py | 4 -- vllm/platforms/xpu.py | 4 -- 15 files changed, 12 insertions(+), 210 deletions(-) diff --git 
a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index cb87c44cc399..46f7d58c438c 100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -32,10 +32,6 @@ def _test_stopping(llm: LLM, assert output.stop_reason == expected_reason -def _set_async_mode(llm, is_async): - llm.llm_engine.scheduler[0].use_async_output_proc = is_async - - def _stop_basic(llm): _test_stopping(llm, stop=["."], @@ -103,40 +99,8 @@ def test_stop_strings(): # async output processing below. llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) - if envs.VLLM_USE_V1: - _stop_basic(llm) - else: - _set_async_mode(llm, True) - _stop_basic(llm) - - _set_async_mode(llm, False) - _stop_basic(llm) - - if envs.VLLM_USE_V1: - _stop_multi_tokens(llm) - else: - _set_async_mode(llm, True) - _stop_multi_tokens(llm) - - _set_async_mode(llm, False) - _stop_multi_tokens(llm) - - if envs.VLLM_USE_V1: - _stop_partial_token(llm) - else: - _set_async_mode(llm, True) - _stop_partial_token(llm) - - _set_async_mode(llm, False) - _stop_partial_token(llm) - - if envs.VLLM_USE_V1: - # FIXME: this does not respect include_in_output=False - # _stop_token_id(llm) - pass - else: - _set_async_mode(llm, True) - _stop_token_id(llm) - - _set_async_mode(llm, False) - _stop_token_id(llm) + _stop_basic(llm) + _stop_multi_tokens(llm) + _stop_partial_token(llm) + # FIXME: this does not respect include_in_output=False + # _stop_token_id(llm) diff --git a/tests/v1/engine/test_processor_multi_modal_uuids.py b/tests/v1/engine/test_processor_multi_modal_uuids.py index bdd41eece231..3a7bcb957182 100644 --- a/tests/v1/engine/test_processor_multi_modal_uuids.py +++ b/tests/v1/engine/test_processor_multi_modal_uuids.py @@ -6,7 +6,6 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig -from vllm.platforms.interface import UnspecifiedPlatform from vllm.sampling_params import SamplingParams from vllm.v1.engine import processor as processor_mod from vllm.v1.engine.processor import Processor @@ -33,15 +32,6 @@ def _mk_processor(monkeypatch, "__post_init__", lambda self, *args: None, raising=True) - monkeypatch.setattr(UnspecifiedPlatform, - "is_async_output_supported", - classmethod(lambda cls, enforce_eager: True), - raising=True) - monkeypatch.setattr( - ModelConfig, - "verify_async_output_proc", - lambda self, parallel_config, speculative_config, device_config: None, - raising=True) monkeypatch.setattr(ModelConfig, "verify_with_parallel_config", lambda self, parallel_config: None, diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 28c24f62895a..f6b8a18dd7c2 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -29,24 +29,6 @@ def test_unsupported_configs(monkeypatch): }, ).create_engine_config() - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - preemption_mode="swap", - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - disable_async_output_proc=True, - ).create_engine_config() - - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - scheduler_delay_factor=1.2, - ).create_engine_config() - def test_enable_by_default_fallback(monkeypatch): with monkeypatch.context() as m: diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ddd8de4324f6..e31a78ba33ba 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -454,9 +454,6 @@ def 
__post_init__(self): self.try_verify_and_update_config() if self.model_config is not None: - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_dual_chunk_attention_config( self.load_config) @@ -877,7 +874,6 @@ def __str__(self): f"served_model_name={self.model_config.served_model_name}, " f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa - f"use_async_output_proc={self.model_config.use_async_output_proc}, " f"pooler_config={self.model_config.pooler_config!r}, " f"compilation_config={self.compilation_config!r}") diff --git a/vllm/config/model.py b/vllm/config/model.py index 921322bb475c..b53029dc8c3e 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -223,8 +223,6 @@ class ModelConfig: that this name(s) will also be used in `model_name` tag content of prometheus metrics, if multiple names provided, metrics tag will take the first one.""" - use_async_output_proc: bool = True - """Whether to use async output processor.""" config_format: Union[str, ConfigFormat] = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it @@ -1119,37 +1117,6 @@ def verify_dual_chunk_attention_config( raise ValueError("please set VLLM_ATTENTION_BACKEND to " f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") - def verify_async_output_proc(self, parallel_config, speculative_config, - device_config) -> None: - if not self.use_async_output_proc: - # Nothing to check - return - - if parallel_config.pipeline_parallel_size > 1: - self.use_async_output_proc = False - return - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - from vllm.platforms import current_platform - if not current_platform.is_async_output_supported(self.enforce_eager): - self.use_async_output_proc = False - return - - if envs.VLLM_USE_RAY_SPMD_WORKER: - self.use_async_output_proc = False - return - - # Async postprocessor is not necessary for pooling models - # since there is no token generation - if self.runner_type == "pooling": - self.use_async_output_proc = False - - # Reminder: Please update docs/features/compatibility_matrix.md - # If the feature combo become valid - if speculative_config: - self.use_async_output_proc = False - def verify_with_parallel_config( self, parallel_config: ParallelConfig, @@ -1173,15 +1140,12 @@ def verify_with_parallel_config( self._verify_with_expert_parallelism() pipeline_parallel_size = parallel_config.pipeline_parallel_size - if pipeline_parallel_size > 1: - if not self.registry.is_pp_supported_model(self.architectures, - self): - raise NotImplementedError( - "Pipeline parallelism is not supported for this model. " - "Supported models implement the `SupportsPP` interface.") - - if self.use_async_output_proc: - self.use_async_output_proc = False + if (pipeline_parallel_size > 1 + and not self.registry.is_pp_supported_model( + self.architectures, self)): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. 
" + "Supported models implement the `SupportsPP` interface.") def get_sliding_window(self) -> Optional[int]: """Get the sliding window size from the HF text config if present.""" diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index f0f67bab9d6f..daf094d2df5c 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -3,7 +3,7 @@ import hashlib from dataclasses import field -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from pydantic import SkipValidation, model_validator from pydantic.dataclasses import dataclass @@ -18,7 +18,6 @@ logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] -PreemptionMode = Literal["swap", "recompute"] SchedulerPolicy = Literal["fcfs", "priority"] @@ -78,10 +77,6 @@ class SchedulerConfig: 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" - delay_factor: float = 0.0 - """Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -103,14 +98,6 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - preemption_mode: Optional[PreemptionMode] = None - """Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead.""" - send_delta_data: bool = False """Private API. If used, scheduler sends delta data to workers instead of an entire data. It should be enabled only diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 242fcf501bfc..fef4177b3a33 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -409,9 +409,7 @@ class EngineArgs: get_field(LoadConfig, "model_loader_extra_config") ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns - preemption_mode: Optional[str] = SchedulerConfig.preemption_mode - scheduler_delay_factor: float = SchedulerConfig.delay_factor enable_chunked_prefill: Optional[ bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input @@ -439,7 +437,6 @@ class EngineArgs: ObservabilityConfig.otlp_traces_endpoint collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ ObservabilityConfig.collect_detailed_traces - disable_async_output_proc: bool = not ModelConfig.use_async_output_proc scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls @@ -561,14 +558,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["enable_prompt_embeds"]) model_group.add_argument("--served-model-name", **model_kwargs["served_model_name"]) - # This one is a special case because it is the - # opposite of ModelConfig.use_async_output_proc - model_group.add_argument( - "--disable-async-output-proc", - action="store_true", - default=EngineArgs.disable_async_output_proc, - help="Disable async output processing. 
This may result in " - "lower performance.") model_group.add_argument("--config-format", **model_kwargs["config_format"]) # This one is a special case because it can bool @@ -897,10 +886,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **scheduler_kwargs["long_prefill_token_threshold"]) scheduler_group.add_argument("--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"]) - scheduler_group.add_argument("--scheduler-delay-factor", - **scheduler_kwargs["delay_factor"]) - scheduler_group.add_argument("--preemption-mode", - **scheduler_kwargs["preemption_mode"]) # multi-step scheduling has been removed; corresponding arguments # are no longer supported. scheduler_group.add_argument("--scheduling-policy", @@ -1029,7 +1014,6 @@ def create_model_config(self) -> ModelConfig: interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, skip_mm_profiling=self.skip_mm_profiling, - use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, @@ -1395,11 +1379,9 @@ def create_engine_config( max_model_len=model_config.max_model_len, cuda_graph_sizes=self.cuda_graph_sizes, num_lookahead_slots=num_lookahead_slots, - delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, - preemption_mode=self.preemption_mode, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, @@ -1492,22 +1474,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - if self.preemption_mode != SchedulerConfig.preemption_mode: - _raise_or_fallback(feature_name="--preemption-mode", - recommend_to_remove=True) - return False - - if (self.disable_async_output_proc - != EngineArgs.disable_async_output_proc): - _raise_or_fallback(feature_name="--disable-async-output-proc", - recommend_to_remove=True) - return False - - if self.scheduler_delay_factor != SchedulerConfig.delay_factor: - _raise_or_fallback(feature_name="--scheduler-delay-factor", - recommend_to_remove=True) - return False - # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: _raise_or_fallback(feature_name=model_config.architectures, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0ab806fcb8b5..092d3f276d1c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -137,8 +137,6 @@ class LLM: back to the eager mode. disable_custom_all_reduce: See [ParallelConfig][vllm.config.ParallelConfig]. - disable_async_output_proc: Disable async output processing. - This may result in lower performance. hf_token: The token to use as HTTP bearer authorization for remote files . If `True`, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). 
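The next hunks drop `disable_async_output_proc` from `LLM.__init__` and from the arguments forwarded to `EngineArgs`, so the knob disappears from the public constructor entirely. A minimal caller-side sketch of the effect (the model name is an arbitrary placeholder, not taken from this patch):

    from vllm import LLM

    # Previously the flag could be passed through to the engine:
    # llm = LLM(model="facebook/opt-125m", disable_async_output_proc=True)

    # With this series applied, the keyword no longer exists; callers simply
    # omit it, and async output processing is no longer user-configurable.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
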
@@ -188,7 +186,6 @@ def __init__( enforce_eager: bool = False, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, - disable_async_output_proc: bool = False, hf_token: Optional[Union[bool, str]] = None, hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None, @@ -286,7 +283,6 @@ def __init__( enforce_eager=enforce_eager, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - disable_async_output_proc=disable_async_output_proc, hf_token=hf_token, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 3b566e88a9ec..7a753d608a43 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -137,10 +137,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor): def _init_executor(self) -> None: """Initialize the worker and load the model. """ - assert self.vllm_config.scheduler_config.delay_factor == 0.0, \ - ("ExecutorWithExternalLauncher needs deterministic " - "execution, so it" - "does not support delay_factor in scheduling") if envs.VLLM_USE_V1: assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \ ("To get deterministic execution in V1, " diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 544e091491bf..cd41832bc2ea 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -126,10 +126,6 @@ def set_device(cls, device: torch.device) -> None: """ torch.cpu.set_device(device) - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def inference_mode(cls): return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 87d8f2b7481b..c263e2afe83b 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -96,16 +96,6 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def is_fully_connected(cls, device_ids: list[int]) -> bool: raise NotImplementedError diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 53fc762dce54..c43580ac5da1 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -275,13 +275,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: """Get the total memory of a device in bytes.""" raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - """ - Check if the current platform supports async output. - """ - raise NotImplementedError - @classmethod def inference_mode(cls): """A device-specific wrapper of `torch.inference_mode`. 
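With `is_async_output_supported` removed from the base `Platform` interface (and from the CPU/CUDA implementations above; the ROCm/TPU/XPU ones follow below), platform classes no longer have to answer this V0-only capability query. A rough sketch of what that means for a platform subclass, assuming a hypothetical out-of-tree platform named `MyAcceleratorPlatform`:

    from vllm.platforms.interface import Platform

    class MyAcceleratorPlatform(Platform):
        # Before this series, subclasses were expected to implement:
        #
        #     @classmethod
        #     def is_async_output_supported(cls, enforce_eager) -> bool:
        #         return False
        #
        # That hook no longer exists on Platform, so it can simply be
        # deleted. The remaining hooks shown in the hunks above
        # (get_device_name, get_device_total_memory, inference_mode, ...)
        # are untouched.
        @classmethod
        def get_device_name(cls, device_id: int = 0) -> str:
            return "my-accelerator"
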
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4f540fe965e2..dce2924ac7a9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -310,16 +310,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager and not envs.VLLM_USE_V1: - logger.warning( - "To see benefits of async output processing, enable CUDA " - "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") - return False - return True - @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: from vllm.config.compilation import CUDAGraphMode diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4e4db116abca..9852d948bc4b 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -75,10 +75,6 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: raise NotImplementedError - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return False - @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 67ef058df10f..4d3bef4b4294 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -98,10 +98,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) return device_props.total_memory - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True - @classmethod def inference_mode(cls): return torch.no_grad() From 0de3fac1256801f4733eb3cf4d79b1302bac996d Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Sun, 21 Sep 2025 22:34:45 +0530 Subject: [PATCH 07/17] feat: Enable engine-level arguments with speculators models (#25250) Signed-off-by: Rahul Tuli Co-authored-by: Claude --- .../speculators/test_eagle3.py | 54 ++++++++++++------- vllm/config/model.py | 12 +---- vllm/engine/arg_utils.py | 35 +++++------- vllm/transformers_utils/config.py | 46 +++++++++++++--- .../configs/speculators/base.py | 52 ++++++++++++------ 5 files changed, 121 insertions(+), 78 deletions(-) diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 45ddb2178722..368238b3a720 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -3,38 +3,52 @@ import pytest import torch +from vllm.config import SpeculativeConfig from vllm.model_executor.models.interfaces import supports_eagle3 -@pytest.mark.parametrize( - "model_path", - [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")]) -def test_llama(vllm_runner, example_prompts, model_path, monkeypatch): +@pytest.mark.parametrize("model_path", [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator"), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator"), +]) +def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, + monkeypatch): + """ + Test Eagle3 speculators models properly initialize speculative decoding. + + This test verifies: + 1. Eagle3 support is detected for the model + 2. 
Speculative config is automatically initialized from embedded config + 3. The draft model path is correctly set to the speculators model + 4. Speculative tokens count is valid + 5. Text generation works with speculative decoding enabled + """ # Set environment variable for V1 engine serialization monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: + # Verify Eagle3 support is detected eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported + assert eagle3_supported, f"Eagle3 should be supported for {model_path}" - vllm_outputs = vllm_model.generate_greedy(example_prompts, - max_tokens=20) - print(vllm_outputs) - assert vllm_outputs + vllm_config = vllm_model.llm.llm_engine.vllm_config + assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \ + "Speculative config should be initialized for speculators model" -@pytest.mark.parametrize( - "model_path", - [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")]) -def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch): - # Set environment variable for V1 engine serialization - monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + spec_config = vllm_config.speculative_config + assert spec_config.num_speculative_tokens > 0, \ + (f"Expected positive speculative tokens, " + f"got {spec_config.num_speculative_tokens}") - with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model: - eagle3_supported = vllm_model.apply_model(supports_eagle3) - assert eagle3_supported + assert spec_config.model == model_path, \ + f"Draft model should be {model_path}, got {spec_config.model}" vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens=20) - print(vllm_outputs) - assert vllm_outputs + assert vllm_outputs, \ + f"No outputs generated for speculators model {model_path}" diff --git a/vllm/config/model.py b/vllm/config/model.py index b53029dc8c3e..95fe52883db0 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -27,8 +27,7 @@ ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, - is_interleaved, maybe_override_with_speculators_target_model, - try_get_generation_config, try_get_safetensors_metadata, + is_interleaved, try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, uses_mrope) from vllm.transformers_utils.runai_utils import (ObjectStorageModel, is_runai_obj_uri) @@ -416,15 +415,6 @@ def __post_init__( self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - if self.runner != "draft": - # If we're not running the draft model, check for speculators config - # If speculators config, set model / tokenizer to be target model - self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501 - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code) - if (backend := envs.VLLM_ATTENTION_BACKEND ) and backend == "FLASHINFER" and find_spec("flashinfer") is None: raise ValueError( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fef4177b3a33..7e00260caa39 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -41,7 +41,8 @@ from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from 
vllm.transformers_utils.config import get_model_path, is_interleaved +from vllm.transformers_utils.config import (get_model_path, is_interleaved, + maybe_override_with_speculators) from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) @@ -1082,29 +1083,8 @@ def create_speculative_config( provided as a JSON string input via CLI arguments or directly as a dictionary from the engine. """ - - from vllm.transformers_utils.config import get_config - from vllm.transformers_utils.configs.speculators.base import ( - SpeculatorsConfig) - if self.speculative_config is None: - hf_config = get_config( - self.hf_config_path or target_model_config.model, - self.trust_remote_code, self.revision, self.code_revision, - self.config_format) - - # if loading a SpeculatorsConfig, load the speculative_config - # details from the config directly - # no user input required / expected - if isinstance(hf_config, SpeculatorsConfig): - # We create one since we don't create one - self.speculative_config = {} - self.speculative_config[ - "num_speculative_tokens"] = hf_config.num_lookahead_tokens - self.speculative_config["model"] = target_model_config.model - self.speculative_config["method"] = hf_config.method - else: - return None + return None # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine @@ -1139,6 +1119,15 @@ def create_engine_config( device_config = DeviceConfig( device=cast(Device, current_platform.device_type)) + + (self.model, self.tokenizer, + self.speculative_config) = maybe_override_with_speculators( + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + vllm_speculative_config=self.speculative_config, + ) model_config = self.create_model_config() # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 52e2c18a7784..9eed46678866 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: return config -def maybe_override_with_speculators_target_model( +def maybe_override_with_speculators( model: str, tokenizer: str, trust_remote_code: bool, revision: Optional[str] = None, + vllm_speculative_config: Optional[dict[str, Any]] = None, **kwargs, -) -> tuple[str, str]: +) -> tuple[str, str, Optional[dict[str, Any]]]: """ - If running a speculators config, override running model with target model + Resolve model configuration when speculators are detected. + + Checks if the provided model is a speculators model and if so, extracts + the target model configuration and builds the speculative config. 
+ + Args: + model: Model name or path + tokenizer: Tokenizer name or path + trust_remote_code: Whether to trust remote code + revision: Model revision + vllm_speculative_config: Existing vLLM speculative config + + Returns: + Tuple of (resolved_model, resolved_tokenizer, speculative_config) """ is_gguf = check_gguf_file(model) if is_gguf: @@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model( token=_get_hf_token(), **kwargs, ) - spec_config = config_dict.get("speculators_config", None) - # Return the target model - if spec_config is not None: - model = tokenizer = spec_config["verifier"]["name_or_path"] - return model, tokenizer + speculators_config = config_dict.get("speculators_config") + + if speculators_config is None: + # No speculators config found, return original values + return model, tokenizer, vllm_speculative_config + + # Speculators format detected - process overrides + from vllm.transformers_utils.configs.speculators.base import ( + SpeculatorsConfig) + + vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config( + config_dict=config_dict) + + # Set the draft model to the speculators model + vllm_speculative_config["model"] = model + + # Override model and tokenizer with the verifier model from config + verifier_model = speculators_config["verifier"]["name_or_path"] + model = tokenizer = verifier_model + + return model, tokenizer, vllm_speculative_config def get_config( diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index d7c16e180c70..53128b4eecb0 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -24,6 +24,12 @@ def from_pretrained( config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + vllm_config = cls.extract_vllm_speculative_config(config_dict) + return cls(**vllm_config) + + @classmethod + def extract_vllm_speculative_config( + cls, config_dict: dict[str, Any]) -> dict[str, Any]: speculators_model_type = config_dict.get("speculators_model_type") if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: raise ValueError( @@ -34,11 +40,12 @@ def from_pretrained( # TODO: @dsikka - use speculators pydantic model to validate cls.validate_speculators_config(config_dict=config_dict) # Convert from speculators config -> format that can be ingested by vLLM - vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict) + vllm_config = cls.build_vllm_speculative_config( + config_dict=config_dict) # Apply anything specific to the supported algorithm algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] algo_updater(config_dict=config_dict, vllm_config=vllm_config) - return cls(**vllm_config) + return vllm_config @classmethod def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: @@ -60,32 +67,45 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: "'transformer_layer_config' must be a dictionary if provided") @classmethod - def convert_speculators_to_vllm( + def build_vllm_speculative_config( cls, config_dict: dict[str, Any]) -> dict[str, Any]: """ - Convert speculators config format to vLLM format. - - This method handles the translation of field names and structure - between speculators and vLLM formats. - + Build vLLM-compatible speculative configuration from speculators format. 
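To make the conversion concrete, here is a rough sketch of the mapping performed by `extract_vllm_speculative_config` / `build_vllm_speculative_config`, using only the keys referenced in this patch; the model type and verifier path are illustrative placeholders, and algorithm-specific updaters may add further keys:

# Hypothetical speculators-format config dict (keys as read by the code above).
config_dict = {
    "speculators_model_type": "eagle3",  # illustrative; must be a supported type
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "example-org/target-model"},
    },
    "transformer_layer_config": {"hidden_size": 4096},
}

# build_vllm_speculative_config(config_dict) would return roughly:
# {
#     "method": "eagle3",
#     "num_speculative_tokens": 5,
#     "target_model": "example-org/target-model",
#     "hidden_size": 4096,   # merged in from transformer_layer_config
# }
# maybe_override_with_speculators() then sets "model" in this dict to the
# speculators checkpoint path and returns the verifier path as the resolved
# model and tokenizer.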
+ + This method extracts and transforms speculative configuration from the + speculators format into the structure expected by vLLM. + + Args: + config_dict: Configuration dictionary in speculators format + Returns: - Dictionary with vLLM-compatible configuration + Dictionary with vLLM-compatible speculative configuration """ - # Currently we only support one proposal method + # Extract speculators configuration spec_config = config_dict["speculators_config"] - first_method = spec_config.get("proposal_methods")[0] - num_lookahead_tokens = first_method.get("speculative_tokens") - if num_lookahead_tokens is None: + # Currently we only support one proposal method + proposal_methods = spec_config.get("proposal_methods") + if not proposal_methods: + raise ValueError("No proposal methods found in speculators config") + + first_method = proposal_methods[0] + num_speculative_tokens = first_method.get("speculative_tokens") + + if num_speculative_tokens is None: raise ValueError( "Missing 'speculative_tokens' in proposal method. " f"Got: {first_method}") - # Build base vLLM config + # Build base vLLM speculative configuration vllm_config = { "method": config_dict.get("speculators_model_type"), - "num_lookahead_tokens": num_lookahead_tokens, + "num_speculative_tokens": num_speculative_tokens, "target_model": spec_config.get("verifier")["name_or_path"] } - vllm_config.update(config_dict["transformer_layer_config"]) + + # Merge transformer layer configuration if present + transformer_config = config_dict.get("transformer_layer_config", {}) + vllm_config.update(transformer_config) + return vllm_config From 69a760186155ba9ceeac0e189721809d09c65c9a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Sep 2025 10:37:11 -0700 Subject: [PATCH 08/17] [V0 Deprecation] Remove V0 sampling metadata (#25345) Signed-off-by: Woosuk Kwon --- .../vllm_add_dummy_model/my_llava.py | 8 +++----- .../vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py | 8 +++----- vllm/model_executor/__init__.py | 2 -- vllm/model_executor/layers/logits_processor.py | 2 -- vllm/model_executor/models/apertus.py | 5 +---- vllm/model_executor/models/arcee.py | 7 +++---- vllm/model_executor/models/arctic.py | 5 +---- vllm/model_executor/models/aria.py | 7 ++----- vllm/model_executor/models/aya_vision.py | 5 +---- vllm/model_executor/models/baichuan.py | 5 +---- vllm/model_executor/models/bailing_moe.py | 5 +---- vllm/model_executor/models/bamba.py | 5 +---- vllm/model_executor/models/blip2.py | 5 +---- vllm/model_executor/models/bloom.py | 5 +---- vllm/model_executor/models/chameleon.py | 5 +---- vllm/model_executor/models/chatglm.py | 5 +---- vllm/model_executor/models/cohere2_vision.py | 5 +---- vllm/model_executor/models/commandr.py | 6 ++---- vllm/model_executor/models/dbrx.py | 5 +---- vllm/model_executor/models/deepseek.py | 5 +---- vllm/model_executor/models/deepseek_eagle.py | 5 +---- vllm/model_executor/models/deepseek_mtp.py | 9 ++------- vllm/model_executor/models/deepseek_v2.py | 5 +---- vllm/model_executor/models/deepseek_vl2.py | 5 +---- vllm/model_executor/models/dots1.py | 5 +---- vllm/model_executor/models/ernie45_moe.py | 5 +---- vllm/model_executor/models/ernie45_vl.py | 5 +---- vllm/model_executor/models/ernie45_vl_moe.py | 5 +---- vllm/model_executor/models/ernie_mtp.py | 8 ++------ vllm/model_executor/models/exaone.py | 5 +---- vllm/model_executor/models/exaone4.py | 5 +---- vllm/model_executor/models/falcon.py | 5 +---- vllm/model_executor/models/falcon_h1.py | 5 +---- vllm/model_executor/models/fuyu.py | 4 +--- 
vllm/model_executor/models/gemma.py | 5 +---- vllm/model_executor/models/gemma2.py | 5 +---- vllm/model_executor/models/gemma3.py | 5 +---- vllm/model_executor/models/gemma3_mm.py | 5 +---- vllm/model_executor/models/gemma3n.py | 5 +---- vllm/model_executor/models/gemma3n_mm.py | 5 +---- vllm/model_executor/models/glm4.py | 5 +---- vllm/model_executor/models/glm4_1v.py | 5 +---- vllm/model_executor/models/glm4_moe.py | 5 +---- vllm/model_executor/models/glm4_moe_mtp.py | 9 ++------- vllm/model_executor/models/gpt2.py | 5 +---- vllm/model_executor/models/gpt_bigcode.py | 5 +---- vllm/model_executor/models/gpt_j.py | 4 +--- vllm/model_executor/models/gpt_neox.py | 5 +---- vllm/model_executor/models/gpt_oss.py | 7 ++----- vllm/model_executor/models/granite.py | 9 +++------ vllm/model_executor/models/granite_speech.py | 7 +------ vllm/model_executor/models/granitemoe.py | 9 +++------ vllm/model_executor/models/granitemoehybrid.py | 5 +---- vllm/model_executor/models/granitemoeshared.py | 9 +++------ vllm/model_executor/models/grok1.py | 5 +---- vllm/model_executor/models/hunyuan_v1.py | 5 +---- vllm/model_executor/models/hyperclovax_vision.py | 5 +---- vllm/model_executor/models/idefics3.py | 7 ++----- vllm/model_executor/models/interfaces_base.py | 3 --- vllm/model_executor/models/internlm2.py | 5 +---- vllm/model_executor/models/interns1.py | 5 +---- vllm/model_executor/models/internvl.py | 5 +---- vllm/model_executor/models/jais.py | 5 +---- vllm/model_executor/models/jamba.py | 5 +---- vllm/model_executor/models/keye.py | 5 +---- vllm/model_executor/models/kimi_vl.py | 5 +---- vllm/model_executor/models/lfm2.py | 7 ++----- vllm/model_executor/models/llama.py | 5 +---- vllm/model_executor/models/llama_eagle3.py | 5 +---- vllm/model_executor/models/llava.py | 5 +---- vllm/model_executor/models/llava_next.py | 5 +---- vllm/model_executor/models/llava_next_video.py | 5 +---- vllm/model_executor/models/llava_onevision.py | 5 +---- vllm/model_executor/models/mamba.py | 7 ++----- vllm/model_executor/models/mamba2.py | 7 ++----- vllm/model_executor/models/medusa.py | 3 +-- vllm/model_executor/models/midashenglm.py | 4 +--- vllm/model_executor/models/mimo.py | 5 +---- vllm/model_executor/models/mimo_mtp.py | 8 ++------ vllm/model_executor/models/minicpm.py | 5 +---- vllm/model_executor/models/minicpm_eagle.py | 5 +---- vllm/model_executor/models/minicpmv.py | 4 +--- vllm/model_executor/models/minimax_text_01.py | 7 ++----- vllm/model_executor/models/minimax_vl_01.py | 5 +---- vllm/model_executor/models/mistral3.py | 5 +---- vllm/model_executor/models/mixtral.py | 5 +---- vllm/model_executor/models/mllama4.py | 5 +---- vllm/model_executor/models/molmo.py | 7 ++----- vllm/model_executor/models/mpt.py | 5 +---- vllm/model_executor/models/nano_nemotron_vl.py | 5 +---- vllm/model_executor/models/nemotron.py | 5 +---- vllm/model_executor/models/nemotron_h.py | 5 +---- vllm/model_executor/models/nemotron_nas.py | 5 +---- vllm/model_executor/models/nemotron_vl.py | 5 +---- vllm/model_executor/models/olmo.py | 5 +---- vllm/model_executor/models/olmo2.py | 5 +---- vllm/model_executor/models/olmoe.py | 7 ++----- vllm/model_executor/models/opt.py | 5 +---- vllm/model_executor/models/orion.py | 5 +---- vllm/model_executor/models/ovis.py | 4 +--- vllm/model_executor/models/ovis2_5.py | 4 +--- vllm/model_executor/models/paligemma.py | 5 +---- vllm/model_executor/models/persimmon.py | 5 +---- vllm/model_executor/models/phi.py | 4 +--- vllm/model_executor/models/phi3v.py | 5 +---- 
vllm/model_executor/models/phi4_multimodal.py | 5 +---- vllm/model_executor/models/phi4flash.py | 3 --- vllm/model_executor/models/phi4mm.py | 5 +---- vllm/model_executor/models/phimoe.py | 7 ++----- vllm/model_executor/models/pixtral.py | 5 +---- vllm/model_executor/models/plamo2.py | 5 +---- vllm/model_executor/models/qwen.py | 5 +---- vllm/model_executor/models/qwen2.py | 5 +---- vllm/model_executor/models/qwen2_5_omni_thinker.py | 5 +---- vllm/model_executor/models/qwen2_5_vl.py | 5 +---- vllm/model_executor/models/qwen2_audio.py | 5 +---- vllm/model_executor/models/qwen2_moe.py | 5 +---- vllm/model_executor/models/qwen2_vl.py | 5 +---- vllm/model_executor/models/qwen3.py | 5 +---- vllm/model_executor/models/qwen3_moe.py | 5 +---- vllm/model_executor/models/qwen3_next.py | 5 +---- vllm/model_executor/models/qwen3_next_mtp.py | 5 +---- vllm/model_executor/models/qwen3_vl.py | 5 +---- vllm/model_executor/models/seed_oss.py | 5 +---- vllm/model_executor/models/skyworkr1v.py | 5 +---- vllm/model_executor/models/solar.py | 7 ++----- vllm/model_executor/models/stablelm.py | 5 +---- vllm/model_executor/models/starcoder2.py | 5 +---- vllm/model_executor/models/step3_text.py | 7 ++----- vllm/model_executor/models/step3_vl.py | 5 +---- vllm/model_executor/models/tarsier.py | 5 +---- vllm/model_executor/models/transformers.py | 5 +---- vllm/model_executor/models/ultravox.py | 7 ++----- vllm/model_executor/models/voxtral.py | 5 +---- vllm/model_executor/models/whisper.py | 7 ++----- vllm/model_executor/models/zamba2.py | 5 +---- vllm/model_executor/sampling_metadata.py | 7 ------- vllm/v1/spec_decode/eagle.py | 9 +++------ vllm/v1/spec_decode/medusa.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 9 ++++----- vllm/v1/worker/tpu_model_runner.py | 2 +- 141 files changed, 172 insertions(+), 583 deletions(-) delete mode 100644 vllm/model_executor/sampling_metadata.py diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index da97cf7e2b40..b431ad1ed092 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -9,7 +9,6 @@ LlavaForConditionalGeneration, LlavaMultiModalProcessor, LlavaProcessingInfo) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -18,11 +17,10 @@ dummy_inputs=LlavaDummyInputsBuilder) class MyLlava(LlavaForConditionalGeneration): - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py index 8c34407e3e07..a6fafff98e9c 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py @@ -6,16 +6,14 @@ import torch from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata class MyOPTForCausalLM(OPTForCausalLM): - def compute_logits( - self, hidden_states: 
torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # this dummy model always predicts the first token - logits = super().compute_logits(hidden_states, sampling_metadata) + logits = super().compute_logits(hidden_states) if logits is not None: logits.zero_() logits[:, 0] += 1.0 diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index a59aebfac4ff..3c094cfdb553 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -3,11 +3,9 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed __all__ = [ - "SamplingMetadata", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 8226437cb189..2110aa2769b9 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -10,7 +10,6 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform @@ -50,7 +49,6 @@ def forward( self, lm_head: VocabParallelEmbedding, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: if self.logits_as_input: diff --git a/vllm/model_executor/models/apertus.py b/vllm/model_executor/models/apertus.py index f6400b05e110..6dab4ed14345 100644 --- a/vllm/model_executor/models/apertus.py +++ b/vllm/model_executor/models/apertus.py @@ -48,7 +48,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -566,10 +565,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py index be82c2fd5964..1ee378af76c9 100644 --- a/vllm/model_executor/models/arcee.py +++ b/vllm/model_executor/models/arcee.py @@ -399,11 +399,10 @@ def forward( inputs_embeds=inputs_embeds) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata) -> Optional[torch.Tensor]: + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: # Compute final logits from hidden states (last pipeline rank only) - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index b6dd55996841..55d16fd75ceb 100644 --- 
a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -456,10 +455,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index a7cb6b35a4ab..35c1adbdd00b 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -644,10 +643,8 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 687c82ded9d0..0f05f9b4efcd 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -16,7 +16,6 @@ get_optimal_tiled_canvas) from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, @@ -464,7 +463,5 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index ae2503341040..db8d0a871047 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -46,7 +46,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant @@ -421,10 +420,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> 
Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py index 5f6025abf315..82cd4a26a1ba 100644 --- a/vllm/model_executor/models/bailing_moe.py +++ b/vllm/model_executor/models/bailing_moe.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -623,10 +622,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index 397089f31cdf..584981ef3ebf 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -34,7 +34,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -571,10 +570,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a3131aa3812e..b7455fba62c0 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -12,7 +12,6 @@ from vllm.config import CacheConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -704,10 +703,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 4c37622b049c..30816f72a267 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP, SupportsQuant @@ -355,10 +354,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 7a5623648374..79d648d749c6 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -28,7 +28,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -1046,10 +1045,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) # Disallow image tokens which does not include special # begin-image and end-image tokens diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1fc2da3e4d7c..879508400222 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig @@ -437,10 +436,8 @@ def __init__( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 179cc2af8eb3..6d67eb68d51a 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -21,7 +21,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, @@ -478,7 +477,5 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return 
self.language_model.compute_logits(hidden_states) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 7f87e31abdcd..f3929ef3b593 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -46,7 +46,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, row_parallel_weight_loader) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -448,15 +447,14 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: is_not_lora = hasattr(self.model.embed_tokens, 'weight') if is_not_lora: logits = self.logits_processor(self.model.embed_tokens, - hidden_states, sampling_metadata) + hidden_states) else: logits = self.logits_processor(self.model.embed_tokens.base_layer, - hidden_states, sampling_metadata) + hidden_states) return logits diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 003cf4563a22..f863b1da5505 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -24,7 +24,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -462,10 +461,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 59c992188149..ffc843fe033c 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -488,10 +487,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index 2770ddebc48a..ed7e7614800f 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -19,7 +19,6 @@ default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import (DeepseekV2DecoderLayer, DeepseekV3ForCausalLM) -from vllm.model_executor.sampling_metadata import SamplingMetadata from .utils import AutoWeightsLoader, maybe_prefix @@ -222,10 +221,8 @@ 
def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 8fbf16d206a8..92f311ab465b 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -15,7 +15,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .deepseek_v2 import (DeepseekV2DecoderLayer, @@ -124,15 +123,13 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + mtp_layer.shared_head(hidden_states)) return logits @@ -161,11 +158,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, sampling_metadata, - spec_step_idx) + return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 636554bd648f..a99a6679a569 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -56,7 +56,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op @@ -914,10 +913,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index d7ae8206baca..c8ed759d2e97 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -15,7 +15,6 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models.transformers import replace_linear_class @@ -647,10 +646,8 @@ def forward(self, def compute_logits( self, hidden_states: 
torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/dots1.py b/vllm/model_executor/models/dots1.py index 20555e48b73d..2a09234b59ed 100644 --- a/vllm/model_executor/models/dots1.py +++ b/vllm/model_executor/models/dots1.py @@ -52,7 +52,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -534,10 +533,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index ebab018ed67e..d262e9e9da50 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -49,7 +49,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -591,10 +590,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 0d4aced93ca1..74b358034ef3 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -39,7 +39,6 @@ from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -1292,11 +1291,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: """compute logits""" - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def _vision_forward( self, diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 7f791852ceb9..f55016f7ccb3 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -48,7 +48,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, 
maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .ernie45_moe import Ernie4_5_MoeMLP @@ -587,10 +586,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/ernie_mtp.py b/vllm/model_executor/models/ernie_mtp.py index c44626523031..288fbe736c32 100644 --- a/vllm/model_executor/models/ernie_mtp.py +++ b/vllm/model_executor/models/ernie_mtp.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -138,12 +137,10 @@ def compute_logits( self, hidden_states: torch.Tensor, lm_head: ParallelLMHead, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: self.layers[str(self.mtp_start_layer_idx + spec_step_idx)] - logits = self.logits_processor(lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(lm_head, hidden_states) return logits @@ -180,11 +177,10 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, self.lm_head, - sampling_metadata, spec_step_idx) + spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index f503fb0f9364..5dafcd595e4a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -49,7 +49,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -534,10 +533,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py index 9f7d57d93814..c78eedff6670 100644 --- a/vllm/model_executor/models/exaone4.py +++ b/vllm/model_executor/models/exaone4.py @@ -45,7 +45,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -517,10 +516,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> 
Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 42c378e5c389..0c50056d1c52 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig @@ -496,10 +495,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index 757051b3b144..83efdd2e433f 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -33,7 +33,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP @@ -675,10 +674,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 90af859ab92e..53e9e6fe6e46 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -29,7 +29,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.persimmon import PersimmonForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -389,10 +388,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.language_model.logits_processor( - self.language_model.lm_head, hidden_states, sampling_metadata) + self.language_model.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 12eb27503870..c19425b6cb6d 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
-from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -412,10 +411,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 0bdb6c6bf7ae..3f76e1e7d42a 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -41,7 +41,6 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -409,10 +408,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 7246308d5902..77c0ef8cb91d 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -41,7 +41,6 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from ...attention.layers.encoder_only_attention import EncoderOnlyAttention @@ -542,10 +541,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index bee9fbd2c084..0630ee07c347 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -14,7 +14,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -704,10 +703,8 @@ def prepare_attn_masks( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 
ffec3408702c..f4d288fd887e 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -43,7 +43,6 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsQuant @@ -814,10 +813,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: Optional[SamplingMetadata], ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.model.embed_tokens, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.model.embed_tokens, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 8d3079aee0df..2acdba54a257 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -25,7 +25,6 @@ from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -685,10 +684,8 @@ def forward(self, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py index 5e2908a82c41..b9d5e24e9f6f 100644 --- a/vllm/model_executor/models/glm4.py +++ b/vllm/model_executor/models/glm4.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -289,10 +288,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 308b0cb602bc..56ec63438690 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -52,7 +52,6 @@ parallel_state) from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -1654,10 +1653,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: 
- return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 1acbd18091fb..947c6ce62f55 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -51,7 +51,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -703,10 +702,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py index 322c5619c178..c572978e6220 100644 --- a/vllm/model_executor/models/glm4_moe_mtp.py +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -38,7 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name @@ -155,15 +154,13 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)] logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) + mtp_layer.shared_head(hidden_states)) return logits @@ -192,11 +189,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.model.compute_logits(hidden_states, sampling_metadata, - spec_step_idx) + return self.model.compute_logits(hidden_states, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 0f6521e44e6b..24274db148bd 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from ..layers.pooler import DispatchPooler, Pooler @@ -307,10 +306,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return 
logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 745d0b775999..162018450e7c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -329,10 +328,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 77df6ae6f30c..698387fab946 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -41,7 +41,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -329,10 +328,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) + self.lm_head.bias) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index e97db188e27e..7570aefb6e96 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -40,7 +40,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -321,10 +320,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.embed_out, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.embed_out, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index b49fd0d8f88a..4fe59f91124d 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import cdiv @@ -670,10 +669,8 @@ def forward(self, return self.model(input_ids, positions, intermediate_tensors, inputs_embeds) - def 
compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 4f9cc2532bd8..795b38e724ea 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -48,7 +48,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -463,11 +462,9 @@ def forward( inputs_embeds) return model_output - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 221023f1fb65..a5849184339b 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -37,7 +37,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -776,12 +775,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits( - hidden_states, - sampling_metadata, - ) + return self.language_model.compute_logits(hidden_states) def load_weights( self, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index da16c72000c0..07200fef4799 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -48,7 +48,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -511,11 +510,9 @@ def forward( inputs_embeds) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( diff --git a/vllm/model_executor/models/granitemoehybrid.py 
b/vllm/model_executor/models/granitemoehybrid.py index 79c6d8146ba9..e89a1a4a0f7d 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -32,7 +32,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -672,10 +671,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index 0b568a4b2268..a5d118f084e6 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -25,7 +25,6 @@ QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE @@ -311,11 +310,9 @@ def forward( inputs_embeds) return hidden_states - def compute_logits( - self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, + hidden_states: torch.Tensor) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index a59113438337..996e41fe84ff 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -46,7 +46,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -528,10 +527,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/hunyuan_v1.py b/vllm/model_executor/models/hunyuan_v1.py index 4110c8a1fd08..8a23a6b45bc7 100644 --- a/vllm/model_executor/models/hunyuan_v1.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -54,7 +54,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import 
MixtureOfExperts, SupportsLoRA, SupportsPP @@ -1004,10 +1003,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def make_empty_intermediate_tensors( diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 870addd0dcbc..54167f9f1099 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -31,7 +31,6 @@ from vllm.config import VllmConfig from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -962,10 +961,8 @@ def _prepare_multimodal_kwargs(self, **kwargs: object): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights( self, diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 9153a0e2c1e5..18446d126b51 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -738,10 +737,8 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 19a3ef1a3b80..8fdf70e35a2b 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -13,11 +13,9 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler - from vllm.model_executor.sampling_metadata import SamplingMetadata else: VllmConfig = Any Pooler = Any - SamplingMetadata = Any logger = init_logger(__name__) @@ -100,7 +98,6 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - sampling_metadata: SamplingMetadata, ) -> Optional[T]: """Return `None` if TP rank > 0.""" ... 
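Note (illustrative, not part of the patch): the interfaces_base.py hunk above narrows the VllmModelForTextGeneration protocol so that compute_logits is keyed on the hidden states alone. Below is a minimal, self-contained sketch of a conforming implementation and call site; ToyForCausalLM and its plain nn.Linear head are hypothetical stand-ins, not vLLM APIs, and real models keep routing through their LogitsProcessor as the hunks in this diff show.

    # Minimal sketch of the updated compute_logits contract (hypothetical class).
    from typing import Optional

    import torch
    import torch.nn as nn


    class ToyForCausalLM(nn.Module):

        def __init__(self, hidden_size: int = 16, vocab_size: int = 32) -> None:
            super().__init__()
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def compute_logits(
            self,
            hidden_states: torch.Tensor,
        ) -> Optional[torch.Tensor]:
            # Hidden states are the only input; no SamplingMetadata is threaded
            # through logits computation anymore.
            return self.lm_head(hidden_states)


    if __name__ == "__main__":
        model = ToyForCausalLM()
        hidden = torch.randn(2, 16)
        logits = model.compute_logits(hidden)  # old call also passed sampling_metadata
        assert logits is not None and logits.shape == (2, 32)

Call sites that previously built and threaded a SamplingMetadata object through compute_logits now pass only the hidden states, which is the mechanical change repeated across every model file in this diff.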
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index ce94328797ed..221ff08b4384 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -358,10 +357,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.output, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.output, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index b59d1b88cf5c..ba72c288b2b1 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -21,7 +21,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -812,10 +811,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 6a5c565b52e8..f4004e518e3b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -25,7 +25,6 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -1399,10 +1398,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 4fee8c32fd58..0eb1578b4361 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence 
import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig @@ -332,10 +331,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 5b8fbc722686..12a49029195f 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -32,7 +32,6 @@ from vllm.model_executor.models.llama import LlamaMLP as JambaMLP from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -581,10 +580,8 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index afe33b4d4ad2..2e5e276cc1c7 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -21,7 +21,6 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) @@ -1556,10 +1555,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 94a5933a6141..f554077935bf 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -67,7 +67,6 @@ SupportsPP) from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.model_executor.models.utils import merge_multimodal_embeddings -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -484,10 +483,8 @@ def forward( return hidden_states def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, **kwargs) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, **kwargs) + logits = self.logits_processor(self.lm_head, hidden_states, **kwargs) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/lfm2.py b/vllm/model_executor/models/lfm2.py index 927f78c4e4b4..dd97afbeb668 100644 --- a/vllm/model_executor/models/lfm2.py +++ b/vllm/model_executor/models/lfm2.py @@ -27,7 +27,6 @@ 
from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, @@ -542,10 +541,8 @@ def forward( inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f8ea2111fed5..1b03cbef501b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,7 +48,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP @@ -601,10 +600,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7027138dfcb1..fb10af6c53c9 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -21,7 +21,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaForCausalLM) -from vllm.v1.sample.metadata import SamplingMetadata from .utils import AutoWeightsLoader, maybe_prefix @@ -244,10 +243,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) if self.draft_id_to_target_id is None: assert logits.shape[1] == self.config.vocab_size, \ "Expected logits to have shape " \ diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 4f15e1b5762e..e2d7b9f23b28 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -20,7 +20,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -760,10 +759,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) 
-> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index beb3c3310059..c9133fde1455 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -13,7 +13,6 @@ get_anyres_image_grid_shape, unpad_image) from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.parse import ImageSize @@ -563,10 +562,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index cf9852de633f..610fb188d57d 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -13,7 +13,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -464,10 +463,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 46d54452a52d..cee9ddaf94cc 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -14,7 +14,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -934,10 +933,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 9d1017dac8aa..36141a5d5064 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -26,7 +26,6 @@ IsAttentionFree, SupportsPP) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import 
SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -299,10 +298,8 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index b1a4138cb8f6..9c3108146d2e 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -30,7 +30,6 @@ IsAttentionFree) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType @@ -335,10 +334,8 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index b0a96fca2ff8..0ae59dc8dfc2 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -104,12 +104,11 @@ def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]: def compute_logits( self, hidden_states: list[torch.Tensor], - sampling_metadata, ) -> list[torch.Tensor]: logits_lst: list[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): - _logits = self.logits_processor(lm_head, hs, sampling_metadata) + _logits = self.logits_processor(lm_head, hs) if _logits is None: # _logits should only be None on rank > 0, in which case diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 140800dd41c7..82648ba668ca 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -42,7 +42,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -784,9 +783,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.decoder.compute_logits(hidden_states, sampling_metadata) + return self.decoder.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git 
a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index ea5292d0df20..d256c1f3eed7 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix @@ -183,9 +182,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: hidden_states = self.model.norm(hidden_states) - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 09194e9f95d0..b4abe458e477 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -34,7 +34,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix @@ -140,12 +139,10 @@ def compute_logits( self, hidden_states: torch.Tensor, lm_head: ParallelLMHead, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> torch.Tensor: self.mtp_layers[str(self.mtp_start_layer_idx + spec_step_idx)] - logits = self.logits_processor(lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(lm_head, hidden_states) return logits @@ -178,11 +175,10 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: return self.model.compute_logits(hidden_states, self.lm_head, - sampling_metadata, spec_step_idx) + spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 240c23ea2b25..0986ea07406a 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -583,10 +582,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/minicpm_eagle.py b/vllm/model_executor/models/minicpm_eagle.py index 848a97b8bb2a..2af0d546ce63 100644 --- 
a/vllm/model_executor/models/minicpm_eagle.py +++ b/vllm/model_executor/models/minicpm_eagle.py @@ -39,7 +39,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -376,10 +375,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 9b2d84e32151..a17c4f004d75 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -50,7 +50,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -1194,9 +1193,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.llm.compute_logits(hidden_states, sampling_metadata) + return self.llm.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 6ce883be0a83..1d2c7dea811e 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -41,7 +41,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import maybe_prefix -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -742,10 +741,8 @@ def forward(self, return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states.float(), - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states.float()) return logits diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index cc7db849a28b..b2f020f3323e 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.sequence 
import IntermediateTensors @@ -420,10 +419,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index d15776a39362..94e3d7234b6f 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -20,7 +20,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -606,10 +605,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8b3474d80953..bebf0b5adac5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -49,7 +49,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP @@ -594,10 +593,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 2f0e8a2a5e57..131a66b71323 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.utils import initialize_model from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -856,10 +855,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def separate_weights( self, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 2475fe131609..201bf83cac58 100644 --- 
a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -26,7 +26,6 @@ get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather) -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm @@ -1527,10 +1526,8 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 48ac91fa6dde..64d669e8ac3e 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -25,7 +25,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -320,10 +319,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 4f8652c00694..ae50f1aefc6f 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -37,7 +37,6 @@ init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargs, MultiModalKwargsItems, @@ -1192,10 +1191,8 @@ def get_mm_mapping(self) -> MultiModelKeys: def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): adapter_dict = dict(self.mlp1.named_parameters()) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 21f785e4b91a..6bb2f7392cb4 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -45,7 +45,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig @@ -498,10 +497,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - 
logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 1e1f0524bd06..ff571541a60a 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -54,7 +54,6 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronHConfig from vllm.utils import LayerBlockType @@ -622,10 +621,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index f8e38dcd80b5..d474c8db41b2 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasNoOps, SupportsLoRA, SupportsPP @@ -468,10 +467,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py index acda2027401d..3abbff8c717d 100644 --- a/vllm/model_executor/models/nemotron_vl.py +++ b/vllm/model_executor/models/nemotron_vl.py @@ -26,7 +26,6 @@ BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs, InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor) from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import NestedTensors @@ -632,10 +631,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 7be3c16528b5..9fa8760073c1 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( 
ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -391,10 +390,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 3e4c580a1121..2e0b1fb2a13f 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -54,7 +54,6 @@ from vllm.model_executor.models.utils import ( AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Olmo3Config @@ -427,10 +426,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 892e967e4a21..77ece544d490 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -471,10 +470,8 @@ def forward( inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 365aab205b21..4c3ce9f61efb 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -399,10 +398,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: 
Iterable[tuple[str, diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 944a9151d75d..586fea343d6f 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -28,7 +28,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -339,10 +338,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index f1bb18716b40..052e143b27f6 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -39,7 +39,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -558,9 +557,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.llm.compute_logits(hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 5e4758ef8ea5..f18e38ce154d 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -19,7 +19,6 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -630,9 +629,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.llm.compute_logits(hidden_states, sampling_metadata) + logits = self.llm.compute_logits(hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index d6eec77ebcee..aef510230461 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,7 +9,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItems, @@ -403,10 +402,8 @@ def forward(self, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - 
sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3e854e4d561f..23fb7bb85215 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -44,7 +44,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -334,10 +333,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 6f39afbecf35..9cf288e85005 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -59,7 +59,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -346,10 +345,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata, self.lm_head.bias) + self.lm_head.bias) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4522c7043d01..a2b201fe4228 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -681,10 +680,8 @@ def forward(self, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index 25df9e9261d9..d2a3a8cc0496 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import 
SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -1451,10 +1450,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py index aa7c434a44ae..ae153558e37a 100644 --- a/vllm/model_executor/models/phi4flash.py +++ b/vllm/model_executor/models/phi4flash.py @@ -29,7 +29,6 @@ SupportsV0Only) from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .utils import make_layers, maybe_prefix @@ -695,12 +694,10 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: processed_logits = self.logits_processor( self.lm_head, hidden_states, - sampling_metadata, self.embedding_bias, ) return processed_logits diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index b3fc55dab6ec..47b5ad55ab2d 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -18,7 +18,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -1257,10 +1256,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 01d16f1f2c38..3ce67ce37a7a 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -47,7 +47,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -667,10 +666,8 @@ def forward( inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py 
index 142d3251bc67..7b197844c8b6 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -32,7 +32,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalUUIDDict, NestedTensors) @@ -480,10 +479,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 9f1ee36366fd..33ee1cf44afd 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -52,7 +52,6 @@ from vllm.model_executor.models.utils import ( is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -1022,10 +1021,8 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 747094849900..e0c08a6a8827 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -282,10 +281,8 @@ def __init__( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e13e87b93429..c536b0f60c30 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -49,7 +49,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import is_interleaved @@ -510,10 +509,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, 
) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index a7e71309b607..5f27230c913b 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -50,7 +50,6 @@ from vllm.model_executor.models.qwen2_audio import ( Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths) from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, @@ -955,10 +954,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index dbf486374bcf..73b27572a8eb 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -43,7 +43,6 @@ from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.layernorm import RMSNorm # yapf: disable @@ -1256,10 +1255,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index c797b71b5d2e..762ab42e5929 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -34,7 +34,6 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (AudioItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, @@ -481,10 +480,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6c6276a93045..6a9acaf2c3fe 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -51,7 +51,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, 
VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -546,10 +545,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index dd4e7731e0b0..b3c42c257256 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -46,7 +46,6 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -1527,10 +1526,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index dddb47048a1f..ae72fd30c399 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP @@ -328,10 +327,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 029309c49efd..0661b3707ff4 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -54,7 +54,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP @@ -690,10 +689,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, 
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index ce917f92bd2e..24cebc5bfdd8 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -53,7 +53,6 @@ default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -1208,10 +1207,8 @@ def get_mamba_state_shape_from_config( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_next_mtp.py b/vllm/model_executor/models/qwen3_next_mtp.py index c755eeb9b4ea..c054339842e6 100644 --- a/vllm/model_executor/models/qwen3_next_mtp.py +++ b/vllm/model_executor/models/qwen3_next_mtp.py @@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen3_next import (Qwen3NextDecoderLayer, Qwen3NextRMSNorm) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig @@ -266,11 +265,9 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, spec_step_idx: int = 0, ) -> Optional[torch.Tensor]: - return self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + return self.logits_processor(self.lm_head, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index ca232e03767b..aa28c07ddceb 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -45,7 +45,6 @@ from vllm.config import VllmConfig from vllm.distributed import get_pp_group from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -1493,10 +1492,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/seed_oss.py b/vllm/model_executor/models/seed_oss.py index e3c7c700f8fa..a217c820fedf 100644 --- a/vllm/model_executor/models/seed_oss.py +++ b/vllm/model_executor/models/seed_oss.py @@ -47,7 +47,6 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import 
SupportsLoRA, SupportsPP @@ -472,10 +471,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 9857ccdcbe2d..893ce4497c31 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import convert_image_mode from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -897,10 +896,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 94c862258b7a..c774171b9dcd 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -47,7 +47,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -495,10 +494,8 @@ def forward( inputs_embeds) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 9e880ebd5081..e4dfe8d5a9a3 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -42,7 +42,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -332,10 +331,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 62ff9b618275..7f379ab95a03 100644 --- 
a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -43,7 +43,6 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -339,10 +338,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 6a5b540fc817..0cce0c78f8dc 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP @@ -405,10 +404,8 @@ def forward(self, inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index c2940f8e4445..f667266b77bf 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -23,7 +23,6 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -1055,10 +1054,8 @@ def forward( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index c66867315e55..67cf3ccf315d 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -23,7 +23,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems @@ -638,10 +637,8 @@ def forward( def compute_logits( self, 
hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 3bd4d10316ec..475a68bc642b 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -41,7 +41,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalUUIDDict, @@ -798,10 +797,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index f1f11c5fe8f0..12ae9487ad9d 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -18,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors) @@ -616,10 +615,8 @@ def forward(self, inputs_embeds=inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 16a97389cd21..b33e8d09c4be 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -30,7 +30,6 @@ # yapf: disable from vllm.model_executor.models.whisper import WhisperEncoder # yapf: enable -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, MultiModalUUIDDict, @@ -454,10 +453,8 @@ def _parse_and_validate_audio_arrays( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) + return self.language_model.compute_logits(hidden_states) @classmethod def get_speech_to_text_config(cls, model_config: ModelConfig, diff --git a/vllm/model_executor/models/whisper.py 
b/vllm/model_executor/models/whisper.py index 41ae7b129782..de3e4f0592a6 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) @@ -936,10 +935,8 @@ def _parse_and_validate_audio_input( return WhisperAudioInputs(input_features=input_features) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.proj_out, hidden_states, - sampling_metadata) + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index e601bc3adb6e..4350e38e02f9 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import HasInnerState, IsHybrid @@ -1036,7 +1035,6 @@ def get_seqlen_agnostic_capture_inputs( def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: """Compute logits for next token prediction. @@ -1047,8 +1045,7 @@ def compute_logits( Returns: Logits for next token prediction """ - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py deleted file mode 100644 index 8c4548ff7f7d..000000000000 --- a/vllm/model_executor/sampling_metadata.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - - -class SamplingMetadata: - # Placeholder until it can be safely removed. - pass diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2a178ddf4877..5dacf6088696 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -239,7 +239,7 @@ def propose( else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states) # Early exit if there is only one draft token to be generated. 
if self.num_speculative_tokens == 1: @@ -367,8 +367,7 @@ def propose( else: last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) + logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) @@ -678,9 +677,7 @@ def propose_tree( # Get the output logits for the draft tokens. logits = self.model.compute_logits( draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1), - None, - ) + -1)) # Sample a draft token for each child at the next tree level. num_children = self.child_drafts_per_level[level + 1] diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 3e90179e78d9..70b29c05c2a5 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -41,7 +41,7 @@ def propose( ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) - logits = self.model.compute_logits(blocks, None) + logits = self.model.compute_logits(blocks) # Get draft tokens and transpose the result # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d0946e8c5d7d..b0cd0f413307 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2240,7 +2240,7 @@ def execute_model( return output sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states) else: # Rare case. assert not self.is_pooling_model @@ -2258,8 +2258,7 @@ def execute_model( logits = None else: sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, - None) + logits = self.model.compute_logits(sample_hidden_states) model_output_broadcast_data = {} if logits is not None: @@ -2706,7 +2705,7 @@ def _get_prompt_logprobs_dict( req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] - logits = self.model.compute_logits(prompt_hidden_states, None) + logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want @@ -3105,7 +3104,7 @@ def _dummy_sampler_run( # To avoid breaking the sampler, we use a random tensor here instead. hidden_states = torch.rand_like(hidden_states) - logits = self.model.compute_logits(hidden_states, None) + logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) dummy_tensors = lambda v: torch.full( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 48070c1e3e7c..dd11b1dcbe94 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1692,7 +1692,7 @@ def select_hidden_states(self, hidden_states, indices_do_sample): @torch.compile(backend="openxla", fullgraph=True, dynamic=False) def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: - return self.model.compute_logits(sample_hidden_states, None) + return self.model.compute_logits(sample_hidden_states) # TODO: Under SPMD mode, sample_from_logits has correctness issue. # Re-enable the torch.compile once the issue is fixed in torchxla. 
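[Editorial note, not part of the patch: every hunk in the change above applies the same mechanical API migration — `compute_logits` no longer takes a `SamplingMetadata` argument, and the logits processor is called with only the LM head and the hidden states (plus an optional bias where a model already used one). The sketch below illustrates the resulting call pattern with toy stand-in classes; `ToyLogitsProcessor` and `ToyModel` are illustrative assumptions, not vLLM's real implementations.]

# Minimal, self-contained sketch of the new-style signature:
# compute_logits(hidden_states) replaces
# compute_logits(hidden_states, sampling_metadata).
from typing import Optional

import torch
import torch.nn as nn


class ToyLogitsProcessor(nn.Module):
    """Stand-in for a logits processor: projects hidden states to vocab logits."""

    def forward(self,
                lm_head: nn.Linear,
                hidden_states: torch.Tensor,
                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        logits = lm_head(hidden_states)
        if bias is not None:
            logits = logits + bias
        return logits


class ToyModel(nn.Module):
    """Stand-in for a model class touched by this patch."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 32) -> None:
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logits_processor = ToyLogitsProcessor()

    # New-style signature: no SamplingMetadata parameter.
    def compute_logits(
            self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states)


if __name__ == "__main__":
    model = ToyModel()
    sample_hidden_states = torch.randn(4, 16)
    # Callers (e.g. model runners, speculative decoders) now pass only the
    # sampled hidden states; the trailing None / sampling_metadata argument
    # seen in the old call sites is gone.
    logits = model.compute_logits(sample_hidden_states)
    print(logits.shape)  # torch.Size([4, 32])

[End of editorial note.]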
From 9f092a0e417a1abab71dd1918d1276d70ded3cc4 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 22 Sep 2025 04:12:45 +0800 Subject: [PATCH 09/17] [Perf] Further optimization for Qwen3-VL `fast_pos_embed_interpolate` (#25347) Signed-off-by: Isotr0py --- vllm/model_executor/models/qwen3_vl.py | 50 ++++++++++++++++---------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index aa28c07ddceb..98d65dea2739 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -405,25 +405,39 @@ def fast_pos_embed_interpolate(self, dh = h_idxs - h_floor dw = w_idxs - w_floor - w00 = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1) - w01 = ((1 - dh)[:, None] * dw[None, :]).reshape(-1) - w10 = (dh[:, None] * (1 - dw)[None, :]).reshape(-1) - w11 = (dh[:, None] * dw[None, :]).reshape(-1) - - idx00 = (h_floor[:, None] * num_grid_per_side + - w_floor[None, :]).reshape(-1) - idx01 = (h_floor[:, None] * num_grid_per_side + - w_ceil[None, :]).reshape(-1) - idx10 = (h_ceil[:, None] * num_grid_per_side + - w_floor[None, :]).reshape(-1) - idx11 = (h_ceil[:, None] * num_grid_per_side + - w_ceil[None, :]).reshape(-1) - - indices = torch.stack([idx00, idx01, idx10, idx11], dim=0) + # Create meshgrid view for all h, w vars + dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing='ij') + h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, + w_floor, + indexing='ij') + h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, + w_ceil, + indexing='ij') + h_floor_grid_idx = h_floor_grid * num_grid_per_side + h_ceil_grid_idx = h_ceil_grid * num_grid_per_side + + # original computation of weights + # w00 = (1 - dh_grid) * (1 - dw_grid) + # w01 = (1 - dh_grid) * dw_grid + # w10 = dh_grid * (1 - dw_grid) + # w11 = dh_grid * dw_grid + # we reuse w11 here to avoid duplicate + # dh_grid * dw_grid computation + w11 = dh_grid * dw_grid + w10 = dh_grid - w11 + w01 = dw_grid - w11 + w00 = 1 - dh_grid - dw_grid + w11 + + idx00 = h_floor_grid_idx + w_floor_grid + idx01 = h_floor_grid_idx + w_ceil_grid + idx10 = h_ceil_grid_idx + w_floor_grid + idx11 = h_ceil_grid_idx + w_ceil_grid + + indices = torch.stack([idx00, idx01, idx10, idx11], + dim=0).reshape(4, -1) weights = torch.stack([w00, w01, w10, w11], - dim=0).to(dtype=self.dtype, - device=self.device) - weights = weights.unsqueeze(-1) + dim=0).reshape(4, -1, 1) + weights = weights.to(dtype=self.dtype, device=self.device) embeds = self.pos_embed(indices) weighted_embeds = embeds * weights From 6217239d247c29fbff02958bd9282ebdbb42f3af Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 21 Sep 2025 16:03:28 -0700 Subject: [PATCH 10/17] Remove V0 attention backends (#25351) Signed-off-by: Woosuk Kwon --- examples/offline_inference/qwen_1m.py | 1 - tests/compile/test_fusion_attn.py | 5 +- tests/kernels/attention/test_attention.py | 6 +- .../attention/test_attention_selector.py | 1 + .../kernels/attention/test_prefix_prefill.py | 6 +- .../attention/test_rocm_attention_selector.py | 1 + tests/kernels/utils.py | 66 +- tests/models/test_initialization.py | 5 +- .../backends/differential_flash_attn.py | 931 ---------- .../backends/dual_chunk_flash_attn.py | 1495 ----------------- vllm/attention/backends/flash_attn.py | 929 ---------- vllm/attention/backends/flashmla.py | 227 --- vllm/attention/backends/mla/__init__.py | 0 vllm/attention/backends/mla/common.py | 1305 -------------- vllm/attention/backends/rocm_aiter_mla.py | 407 ----- 
vllm/attention/backends/rocm_flash_attn.py | 953 ----------- vllm/attention/backends/triton_mla.py | 111 -- vllm/attention/backends/utils.py | 14 +- vllm/attention/backends/xformers.py | 805 --------- vllm/config/model.py | 7 +- .../kv_transfer/kv_connector/utils.py | 2 +- vllm/engine/arg_utils.py | 15 +- vllm/envs.py | 1 - .../layers/mamba/mamba2_metadata.py | 19 +- vllm/model_executor/models/deepseek_v2.py | 3 +- vllm/platforms/cuda.py | 139 +- vllm/platforms/rocm.py | 61 +- vllm/utils/__init__.py | 2 - 28 files changed, 142 insertions(+), 7375 deletions(-) delete mode 100644 vllm/attention/backends/differential_flash_attn.py delete mode 100644 vllm/attention/backends/dual_chunk_flash_attn.py delete mode 100755 vllm/attention/backends/flash_attn.py delete mode 100644 vllm/attention/backends/flashmla.py delete mode 100644 vllm/attention/backends/mla/__init__.py delete mode 100644 vllm/attention/backends/mla/common.py delete mode 100644 vllm/attention/backends/rocm_aiter_mla.py delete mode 100644 vllm/attention/backends/rocm_flash_attn.py delete mode 100644 vllm/attention/backends/triton_mla.py delete mode 100644 vllm/attention/backends/xformers.py diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py index d8d61667f688..c8d0d91ce7b5 100644 --- a/examples/offline_inference/qwen_1m.py +++ b/examples/offline_inference/qwen_1m.py @@ -5,7 +5,6 @@ from vllm import LLM, SamplingParams -os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index b6bebbba915b..c3f1c7481d1b 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -334,8 +334,9 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): [7, 256, 533] if current_platform.is_cuda() else [8]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("model_name, model_class", MODELS) -@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if - current_platform.is_cuda() else [_Backend.ROCM_FLASH]) +@pytest.mark.parametrize("backend", + [_Backend.FLASHINFER] if current_platform.is_cuda() + else [_Backend.TRITON_ATTN_VLLM_V1]) @pytest.mark.parametrize( "split_attention", [False, True] if current_platform.is_rocm() else [False]) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 7083661575ef..c7abf652f111 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -18,7 +18,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - from vllm.attention.backends.xformers import _make_alibi_bias + from tests.kernels.utils import make_alibi_bias FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. @@ -429,8 +429,8 @@ def test_multi_query_kv_attention( alibi_bias = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, - seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, + seq_lens) output = torch.empty_like(query) start = 0 # Dynamic sequence length not supported with custom attn_bias. 
diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index f8454ad0a4c4..38ab40f88ae0 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -67,6 +67,7 @@ def generate_params(): return params +@pytest.mark.skip(reason="Skipped for now. Should be revisited.") @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params()) def test_env( diff --git a/tests/kernels/attention/test_prefix_prefill.py b/tests/kernels/attention/test_prefix_prefill.py index 8544eab3accc..0695f84aea1a 100644 --- a/tests/kernels/attention/test_prefix_prefill.py +++ b/tests/kernels/attention/test_prefix_prefill.py @@ -11,7 +11,7 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask -from vllm.attention.backends.xformers import _make_alibi_bias +from tests.kernels.utils import make_alibi_bias from vllm.attention.ops.chunked_prefill_paged_decode import ( chunked_prefill_paged_decode) from vllm.attention.ops.prefix_prefill import context_attention_fwd @@ -470,7 +470,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: key = key.unsqueeze(0) value = value.unsqueeze(0) - attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) + attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) output_ref = torch.empty_like(output) seq_start = 0 query_start = 0 @@ -479,7 +479,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # FIXME(DefTruth): Because xformers does not support dynamic sequence # lengths with custom attention bias, we process each prompt one by # one. This is inefficient, especially when we have many short prompts. - # modified from: vllm/attention/backends/xformers.py#L343 + # modified from: vllm/v1/attention/backends/xformers.py#L343 for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)): seq_end = seq_start + seq_len query_end = query_start + query_len diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index d56d3f4638f1..af301d9de435 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -16,6 +16,7 @@ def clear_cache(): _cached_get_attn_backend.cache_clear() +@pytest.mark.skip(reason="Skipped for now. Should be revisited.") def test_selector(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH") diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index c9bf85f6e2a5..8d6ce381976b 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend: Construct the backend instance determined by the backend_name string argument. - "XFORMERS" -> construct xformers backend - - TODO: other backends - Note: at time of writing the Attention wrapper automatically selects its own backend for Attention.forward(); so the backend instance which you generate with this function is not meant to be used for *running* @@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend: * Backend instance ''' - if backend_name == STR_XFORMERS_ATTN_VAL: - # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. 
- from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() - elif backend_name == STR_FLASH_ATTN_VAL: - from vllm.attention.backends.flash_attn import FlashAttentionBackend + if backend_name in (STR_XFORMERS_ATTN_VAL, "XFORMERS_VLLM_V1"): + from vllm.v1.attention.backends.xformers import ( + XFormersAttentionBackend) + return XFormersAttentionBackend() + if backend_name in (STR_FLASH_ATTN_VAL, "FLASH_ATTN_VLLM_V1"): + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend return FlashAttentionBackend() + if backend_name == "TRITON_ATTN_VLLM_V1": + from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionBackend) + return TritonAttentionBackend() + if backend_name == "FLEX_ATTENTION": + from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionBackend) + return FlexAttentionBackend() + if backend_name in ("TORCH_SDPA", "TORCH_SDPA_VLLM_V1"): + from vllm.v1.attention.backends.cpu_attn import TorchSDPABackend + return TorchSDPABackend() + if backend_name == "FLASHINFER": + from vllm.v1.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend() raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") +def make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: list[int], +) -> list[Any]: + """Create ALiBi biases compatible with xFormers attention tests.""" + from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias + + if alibi_slopes is None: + return [None for _ in seq_lens] + + attn_biases: list[Any] = [] + num_heads = alibi_slopes.shape[0] + assert num_heads >= num_kv_heads, ( + "ALiBi slopes expect at least as many heads as KV heads") + + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + bias_tensor = torch.empty( + 1, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias_tensor.mul_(alibi_slopes[:, None, None]) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor)) + + return attn_biases + + def _make_metadata_tensors( seq_lens: Optional[list[int]], context_lens: Optional[list[int]], diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index b9601114a318..bfde6e20a3b1 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -78,9 +78,8 @@ def _initialize_kv_caches_v1(self, vllm_config): return if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"): - # Phi4FlashForCausalLM and MotifForCausalLM - # only supports DIFFERENTIAL_FLASH_ATTN backend - m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") + pytest.skip( + "Differential Flash Attention backend has been removed.") if model_arch == "GptOssForCausalLM": # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU # has cc==8.9 which hasn't supported FA3 yet. 
Remove this hack when diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py deleted file mode 100644 index 87a4558e377d..000000000000 --- a/vllm/attention/backends/differential_flash_attn.py +++ /dev/null @@ -1,931 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""" An implementation of https://arxiv.org/pdf/2410.05258 """ -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -from einops import rearrange - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.flash_attn import FlashAttentionBackend -# yapf: enable -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, - compute_slot_mapping, - compute_slot_mapping_start_idx, - is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, - is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -logger = init_logger(__name__) - - -class DifferentialFlashAttentionBackend(AttentionBackend): - accept_output_buffer = False - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" - return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) - - @staticmethod - def get_name() -> str: - return "DIFFERENTIAL_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: - return DifferentialFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: - return DifferentialFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: - return DifferentialFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class 
DifferentialFlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional[ - "DifferentialFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
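The cumulative *_start_loc convention described in the comments above (and repeated for the encoder field that follows) is easy to trip over; the metadata builder further below derives both tensors with itertools.accumulate. A minimal sketch with made-up lengths, not values from the patch:

    # Illustrative sketch: how the cumulative start locations relate to the
    # per-sequence lengths. Values are made up; only the arithmetic mirrors
    # the removed builder code.
    from itertools import accumulate

    query_lens = [4, 6]           # new tokens per sequence in this step
    seq_lens = [10, 6]            # context + new tokens per sequence
    context_lens = [s - q for s, q in zip(seq_lens, query_lens)]  # [6, 0]

    # Cumulative offsets used to index the flattened token dimension.
    query_start_loc = list(accumulate(query_lens, initial=0))  # [0, 4, 10]
    seq_start_loc = list(accumulate(seq_lens, initial=0))      # [0, 10, 16]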
- encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - # Cross-layer shared attention block tables - cross_layer_shared_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata( - self) -> Optional["DifferentialFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - cross_layer_shared_block_tables = ( - None if self.cross_layer_shared_block_tables is None else - self.cross_layer_shared_block_tables[self.num_prefills:]) - self._cached_decode_metadata = DifferentialFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class DifferentialFlashAttentionMetadataBuilder( - AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.cross_layer_shared_block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - # TODO: add support for chunked prefill and prefix caching. - assert not chunked_prefill_enabled, \ - "chunked prefill is not supported for now" - assert not prefix_cache_hit, "prefix caching is not supported for now" - - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. 
- block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - cross_layer_shared_block_table = [] - if prefix_cache_hit: - cross_layer_shared_block_table = block_tables[seq_id] - elif block_tables is not None: - if curr_sliding_window_block == 0: - cross_layer_shared_block_table = block_tables[seq_id] - else: - cross_layer_shared_block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.cross_layer_shared_block_tables.append( - cross_layer_shared_block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables(self, num_seqs: int, - block_tables: List[List[int]], - graph_block_tables) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - # max_batch_size, max_blocks = self.runner.graph_block_tables.shape - max_batch_size, max_blocks = graph_block_tables.shape - assert max_batch_size >= num_seqs - - # graph_block_tables = self.runner.graph_block_tables[:num_seqs] - graph_block_tables = graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - - self.cross_layer_shared_block_tables.extend([] * - cuda_graph_pad_size) - - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables, self.runner.graph_block_tables) - cross_layer_shared_block_tables = \ - self._get_graph_runner_block_tables( - num_seqs, self.cross_layer_shared_block_tables, - self.runner.cross_layer_shared_graph_block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - cross_layer_shared_block_tables = make_tensor_with_pad( - self.cross_layer_shared_block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return DifferentialFlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - cross_layer_shared_block_tables=cross_layer_shared_block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class DifferentialFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - 
|<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - differential_flash_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if differential_flash_attention_config is None: - differential_flash_attention_config = {} - self.differential_flash_attention_config = \ - differential_flash_attention_config - self.used_shared_kv_cache = kv_sharing_target_layer_name is not None - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - self.lambda_full = None - self.subln = self.differential_flash_attention_config["subln"] - - def split_heads(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. - x = rearrange(x, "... (H two) D -> ... H two D", two=2) - x1 = x[..., 0, :] - x2 = x[..., 1, :] - return x1.contiguous(), x2.contiguous() - - def split_kv_cache(self, x): - # split by num_heads, the stripe pattern is friendly to tensor parallel. 
- if x.numel() == 0: - return torch.empty(0), torch.empty(0) - - x1, x2 = x[0], x[1] - return x1, x2 - - def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, - value: torch.Tensor, kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata): - if kv_cache.numel() > 0 and key is not None and value is not None: - updated_slot_mapping = attn_metadata.slot_mapping - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - def forward_generate_kv_cache( - self, query: torch.Tensor, key: Optional[torch.Tensor], - value: Optional[torch.Tensor], k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: - - head_size = self.head_size - num_heads = self.num_heads // 2 - num_kv_heads = self.num_kv_heads // 2 - - query = query.view(-1, num_heads, head_size) - if key is not None: - assert value is not None - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - else: - assert value is None - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" - assert value.shape[ - 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - if key is not None and value is not None: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens, "query shape mismatch" - assert decode_query.shape[ - 0] == num_decode_tokens, "decode query shape mismatch" - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if k_cache.numel() == 0 \ - or prefill_meta.block_tables is None \ - or prefill_meta.block_tables.numel() == 0: - # normal attention - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ) - assert prefill_output.shape == output[: - num_prefill_tokens].shape - output[:num_prefill_tokens] = prefill_output - else: - raise Exception("prefix caching not supported") - - if decode_meta := attn_metadata.decode_metadata: - block_tables_arg = decode_meta.block_tables - try: - output[num_prefill_tokens:] = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ).squeeze(1) - except Exception as e: - logger.error("Error in PagedAttention.forward_decode: %s", - str(e)) - raise e - - # Reshape the output tensor. 
- return output.view(-1, num_heads, head_size) - - def forward_with_kv_cache_only( - self, - query: torch.Tensor, - k_cache: torch.Tensor, - v_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - ): - if not attn_metadata.decode_metadata: - block_tables_arg = attn_metadata.cross_layer_shared_block_tables - else: - block_tables_arg = attn_metadata.block_tables - - output = flash_attn_with_kvcache( - q=query.unsqueeze(1), - k_cache=k_cache, - v_cache=v_cache, - block_table=block_tables_arg, - cache_seqlens=attn_metadata.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - fa_version=self.vllm_flash_attn_version, - ).squeeze(1) - return output - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DifferentialFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - layer: Attention layer instance. - q: Query tensor with shape = [num_tokens, num_heads, head_size] - k: Key tensor with shape = [num_tokens, num_kv_heads, head_size] - v: Value tensor with shape = [num_tokens, num_kv_heads, head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size, num_kv_heads, head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Output tensor with shape [num_tokens, num_heads, head_size] - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for DifferentialFlashAttentionImpl") - - if self.lambda_full is None: - self.lambda_init = self.differential_flash_attention_config[ - "lambda_init"] - lambda_q1 = self.differential_flash_attention_config["lambda_q1"] - lambda_k1 = self.differential_flash_attention_config["lambda_k1"] - lambda_q2 = self.differential_flash_attention_config["lambda_q2"] - lambda_k2 = self.differential_flash_attention_config["lambda_k2"] - lambda_1 = torch.exp( - torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) - lambda_2 = torch.exp( - torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) - self.lambda_full = lambda_1 - lambda_2 + self.lambda_init - - if not self.used_shared_kv_cache: # need to generate kv-cache - q = q.view(-1, self.num_heads, self.head_size) - k = k.view(-1, self.num_kv_heads, self.head_size) - v = v.view(-1, self.num_kv_heads, self.head_size) - - q1, q2 = self.split_heads(q) - k1, k2 = self.split_heads(k) - v1, v2 = self.split_heads(v) - - # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 - # Split by half along the first dimension. 
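For orientation while reading the rest of this removed forward() (continued below): after split_heads() produces the two head groups, each of the two attention maps attends over both value halves, and the results are combined as a difference, normalized, and rescaled. A minimal single-sequence sketch of just that combination, with torch's scaled_dot_product_attention standing in for the flash-attn kernels and subln assumed to be the config's sub-layer norm module:

    # Minimal sketch of the differential-attention combination (no KV cache,
    # no paging). Tensor shapes are (batch, heads, seq, head_dim); sdpa is a
    # stand-in for the flash-attn calls in the removed backend.
    import torch
    import torch.nn.functional as F

    def diff_attention(q1, q2, k1, k2, v1, v2, lambda_full, lambda_init, subln):
        # Each attention map attends to both value halves, mirroring
        # attn11/attn12 and attn21/attn22 in the removed forward().
        attn1 = torch.cat([
            F.scaled_dot_product_attention(q1, k1, v1, is_causal=True),
            F.scaled_dot_product_attention(q1, k1, v2, is_causal=True),
        ], dim=-1)
        attn2 = torch.cat([
            F.scaled_dot_product_attention(q2, k2, v1, is_causal=True),
            F.scaled_dot_product_attention(q2, k2, v2, is_causal=True),
        ], dim=-1)
        # Differential combination: subtract the second map, normalize, rescale.
        attn = attn1 - lambda_full * attn2
        return subln(attn) * (1 - lambda_init)

In the removed code, lambda_full is derived from the learned lambda_q/lambda_k vectors as exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2)) + lambda_init, exactly as computed a few lines above.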
- kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" - assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" - - if kv_cache1.numel() != 0: - self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) - self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) - - key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) - key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) - else: - key_cache1, value_cache1 = torch.empty(0), torch.empty(0) - key_cache2, value_cache2 = torch.empty(0), torch.empty(0) - attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - - else: # reuse the kv cache, full attention - q = q.view(-1, self.num_heads, self.head_size) - q1, q2 = self.split_heads(q) - # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 - kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) - key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] - key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] - - attn11 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache1, - attn_metadata) - attn12 = self.forward_with_kv_cache_only(q1, key_cache1, - value_cache2, - attn_metadata) - attn11 = attn11.view(q1.shape) - attn12 = attn12.view(q1.shape) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache1, - attn_metadata) - attn22 = self.forward_with_kv_cache_only(q2, key_cache2, - value_cache2, - attn_metadata) - attn21 = attn21.view(q2.shape) - attn22 = attn22.view(q2.shape) - attn2 = torch.cat([attn21, attn22], dim=-1) - - attn = attn1 - self.lambda_full * attn2 - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - # reshape back to 2 * num_head - attn_output = rearrange(attn, - "... H (two D) -> ... (H two) D", - two=2) - attn_output = attn_output.view(-1, self.num_heads * self.head_size) - return attn_output diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py deleted file mode 100644 index de47bb8ebd8f..000000000000 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ /dev/null @@ -1,1495 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with Dual chunk flash attention and sparse attention. 
-""" -import math -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -import torch.distributed -import torch.nn.functional as F - -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import AttentionLayer, AttentionType -from vllm.attention.backends.flash_attn import (FlashAttentionBackend, - FlashAttentionImpl, - FlashAttentionMetadata, - FlashAttentionMetadataBuilder) -from vllm.distributed.parallel_state import get_tensor_model_parallel_rank -from vllm.logger import init_logger -from vllm.utils import async_tensor_h2d -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache, sparse_attn_func) - -logger = init_logger(__name__) - - -class DualChunkFlashAttentionBackend(FlashAttentionBackend): - - accept_output_buffer: bool = False - - @staticmethod - def get_name() -> str: - return "DUAL_CHUNK_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["DualChunkFlashAttentionImpl"]: - return DualChunkFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["DualChunkFlashAttentionMetadata"]: - return DualChunkFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["DualChunkFlashAttentionMetadataBuilder"]: - return DualChunkFlashAttentionMetadataBuilder - - -@dataclass -class DualChunkFlashAttentionMetadata(FlashAttentionMetadata): - # Block size of the paged kv cache. - block_size: int = 16 - - # Original max position embeddings. - original_max_position_embeddings: int = 0 - - # Chunk size - chunk_size: int = 8192 - - # Local size - local_size: int = 1024 - - # (batch_size,). The orig sequence length per sequence. - orig_seq_lens: Optional[List[int]] = None - - # orig_seq_lens stored as a tensor. - orig_seq_lens_tensor: Optional[torch.Tensor] = None - - # Length scaling factor - scaling_factor: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for intra attention. - seq_lens_intra: Optional[torch.Tensor] = None - - # Max sequence length for intra attention. - max_seq_len_intra: Optional[int] = None - - # (batch_size, num_blocks). Block table for intra attention. - block_tables_intra: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for succ attention. - seq_lens_succ: Optional[torch.Tensor] = None - - # Max sequence length for succ attention. - max_seq_len_succ: Optional[int] = None - - # (batch_size, num_blocks). Block table for succ attention. - block_tables_succ: Optional[torch.Tensor] = None - - # (batch_size,). Sequence lengths for inter attention. - seq_lens_inter: Optional[torch.Tensor] = None - - # Max sequence length for inter attention. 
- max_seq_len_inter: Optional[int] = None - - _cached_prefill_metadata: Optional[ - "DualChunkFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["DualChunkFlashAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - prefill_metadata = super().prefill_metadata - if prefill_metadata is None: - return None - - prefill_metadata = DualChunkFlashAttentionMetadata( - **prefill_metadata.asdict_zerocopy()) - - prefill_metadata.orig_seq_lens = ( - None if self.orig_seq_lens is None else - self.orig_seq_lens[:self.num_prefills]) - prefill_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[:self.num_prefills]) - - if self.original_max_position_embeddings > 0: - assert prefill_metadata.orig_seq_lens_tensor is not None - prefill_metadata.scaling_factor = ( - 0.1 * torch.log(prefill_metadata.orig_seq_lens_tensor / - self.original_max_position_embeddings) + - 1.0).clip(min=1) - - self._cached_prefill_metadata = prefill_metadata - return prefill_metadata - - @property - def decode_metadata(self) -> Optional["DualChunkFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - - decode_metadata = super().decode_metadata - if decode_metadata is None: - return None - - decode_metadata = DualChunkFlashAttentionMetadata( - **decode_metadata.asdict_zerocopy()) - - decode_metadata.orig_seq_lens_tensor = ( - None if self.orig_seq_lens_tensor is None else - self.orig_seq_lens_tensor[self.num_prefills:]) - - assert decode_metadata.orig_seq_lens_tensor is not None - assert decode_metadata.block_tables is not None - - cache_seq_lens = decode_metadata.orig_seq_lens_tensor - chunk_len = self.chunk_size - self.local_size - chunk_num_curr = (cache_seq_lens - 1) // chunk_len - batch_size = decode_metadata.num_decode_tokens - - if self.original_max_position_embeddings > 0: - decode_metadata.scaling_factor = (0.1 * torch.log( - cache_seq_lens / self.original_max_position_embeddings) + - 1.0).clip(min=1) - - seq_lens_intra = cache_seq_lens - chunk_num_curr * chunk_len - max_seq_len_intra = seq_lens_intra.max().item() - decode_metadata.seq_lens_intra = seq_lens_intra - decode_metadata.max_seq_len_intra = max_seq_len_intra - - block_tables_intra = torch.zeros( - batch_size, - (max_seq_len_intra - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - st = chunk_num_curr[i] * chunk_len // self.block_size - ed = min( - st + (max_seq_len_intra - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_intra[i, :ed - - st] = decode_metadata.block_tables[i, st:ed] - decode_metadata.block_tables_intra = block_tables_intra - - seq_lens_succ = (chunk_num_curr - - (chunk_num_curr - 1).clip(min=0)) * chunk_len - max_seq_len_succ = seq_lens_succ.max().item() - decode_metadata.seq_lens_succ = seq_lens_succ - decode_metadata.max_seq_len_succ = max_seq_len_succ - if max_seq_len_succ: - block_tables_succ = torch.zeros( - batch_size, - (max_seq_len_succ - 1) // self.block_size + 1, - dtype=decode_metadata.block_tables.dtype, - device=decode_metadata.block_tables.device, - ) - for i in range(batch_size): - start = ((chunk_num_curr[i] - 
1).clip(min=0) * chunk_len // - self.block_size) - end = min( - start + (max_seq_len_succ - 1) // self.block_size + 1, - (cache_seq_lens[i] - 1) // self.block_size + 1, - ) - block_tables_succ[ - i, :end - start] = decode_metadata.block_tables[i, - start:end] - decode_metadata.block_tables_succ = block_tables_succ - - seq_lens_inter = (chunk_num_curr - 1).clip(min=0) * chunk_len - max_seq_len_inter = seq_lens_inter.max().item() - decode_metadata.seq_lens_inter = seq_lens_inter - decode_metadata.max_seq_len_inter = max_seq_len_inter - - self._cached_decode_metadata = decode_metadata - return decode_metadata - - -class DualChunkFlashAttentionMetadataBuilder(FlashAttentionMetadataBuilder): - - def prepare(self): - super().prepare() - self.orig_seq_lens: List[int] = [] - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - super()._add_seq_group(inter_data, chunked_prefill_enabled, - prefix_cache_hit) - for prompt_len, seq_len in zip(inter_data.prompt_lens, - inter_data.seq_lens): - self.orig_seq_lens.append(max(prompt_len, seq_len)) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - attn_metadata = super().build(seq_lens, query_lens, - cuda_graph_pad_size, batch_size) - attn_metadata = DualChunkFlashAttentionMetadata( - **attn_metadata.asdict_zerocopy()) - - device = self.runner.device - attn_metadata.orig_seq_lens = self.orig_seq_lens - attn_metadata.orig_seq_lens_tensor = async_tensor_h2d( - self.orig_seq_lens, torch.int, device, self.runner.pin_memory) - - attn_metadata.block_size = self.runner.block_size - dual_chunk_attn_config = getattr(self.runner.model_config.hf_config, - "dual_chunk_attention_config", {}) - attn_metadata.original_max_position_embeddings = \ - dual_chunk_attn_config.get("original_max_position_embeddings", 0) - attn_metadata.chunk_size = dual_chunk_attn_config.get( - "chunk_size", 8192) - attn_metadata.local_size = dual_chunk_attn_config.get( - "local_size", 1024) - - return attn_metadata - - -class DualChunkFlashAttentionImpl(FlashAttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - The prompts might have different lengths, while the generation tokens - always have length 1. - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. 
- """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - layer_idx: int = -1, - dual_chunk_attention_config: Optional[Dict[str, Any]] = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "DUAL_CHUNK_FLASH_ATTN backend.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - - support_head_sizes = ( - DualChunkFlashAttentionBackend.get_supported_head_sizes()) - - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - - assert dual_chunk_attention_config is not None - self.chunk_size = dual_chunk_attention_config.get("chunk_size", 8192) - self.local_size = dual_chunk_attention_config.get("local_size", 1024) - self.original_max_position_embeddings = dual_chunk_attention_config.get( - "original_max_position_embeddings", 0) - self.sparse_attention_config = dual_chunk_attention_config.get( - "sparse_attention_config", None) - if not self.sparse_attention_config: - logger.warning_once("Sparse attention will not be enabled as " - "sparse attention config is not provided.") - self.sparse_attention_enabled = dual_chunk_attention_config.get( - "sparse_attention_enabled", self.sparse_attention_config - is not None) - self.sparse_attention_threshold = dual_chunk_attention_config.get( - "sparse_attention_threshold", 32768) - self.sparse_attention_last_q = dual_chunk_attention_config.get( - "sparse_attention_last_q", 64) - self.layer_idx = layer_idx - self.dual_chunk_attention_config = dual_chunk_attention_config - - if self.sparse_attention_config: - self.sparse_attention_config = { - int(i): j - for i, j in self.sparse_attention_config[ - self.layer_idx].items() - } - start_head = self.num_heads * get_tensor_model_parallel_rank() - end_head = start_head + self.num_heads - self.sparse_attention_config = [ - self.sparse_attention_config[i] - for i in range(start_head, end_head) - ] - - if self.sparse_attention_enabled: - self.arange = torch.arange(self.sparse_attention_last_q, - device="cuda") - self.last_q_mask = (self.arange[None, None, :, None] - >= self.arange[None, None, None, :]) - - def forward( # type: ignore - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: DualChunkFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with DualChunkFlashAttention. 
- Args: - query: shape = [num_tokens, num_heads * head_size] - query_succ: shape = [num_tokens, num_heads * head_size] - query_inter: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is None, "Output tensor not supported for DualChunk" - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - ( - query, - query_succ, - query_inter, - query_succ_critical, - query_inter_critical, - ) = torch.split(query, query.shape[-1] // 5, dim=-1) - - assert ( - query_succ is not None and query_inter is not None - ), "query_succ and query_inter are required in Dual Chunk Attention." - - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - query_succ = query_succ.view(-1, self.num_heads, self.head_size) - query_inter = query_inter.view(-1, self.num_heads, self.head_size) - query_succ_critical = query_succ_critical.view(-1, self.num_heads, - self.head_size) - query_inter_critical = query_inter_critical.view( - -1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.original_max_position_embeddings > 0: - if prefill_meta := attn_metadata.prefill_metadata: - assert prefill_meta.scaling_factor is not None - assert prefill_meta.query_start_loc is not None - assert prefill_meta.orig_seq_lens is not None - current_start = 0 - query_start_loc_cpu = prefill_meta.query_start_loc.cpu() - for i in range(len(prefill_meta.orig_seq_lens)): - current_end = (current_start + - (query_start_loc_cpu[i + 1] - - query_start_loc_cpu[i]).item()) - key[current_start:current_end].mul_( - prefill_meta.scaling_factor[i]) - current_start = current_end - assert current_end <= attn_metadata.num_prefill_tokens - if decode_meta := attn_metadata.decode_metadata: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - key[attn_metadata.num_prefill_tokens:].mul_( - scaling_factor.unsqueeze(-1).unsqueeze(-1)) - - if kv_cache is not None and kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - output = torch.empty_like(query) - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - decode_query_succ = query_succ[num_prefill_tokens:] - decode_query_inter = query_inter[num_prefill_tokens:] - - # QKV for prefill. 
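Two details of this removed forward() are worth a standalone illustration: the incoming query tensor packs five projections (main, succ, inter, plus the two critical variants) that are recovered with a single torch.split, and keys are rescaled by a log-length factor whenever original_max_position_embeddings is set. All concrete numbers below are examples, not values taken from the patch:

    # Illustrative sketch of the 5-way query split and the key rescaling.
    import math
    import torch

    hidden = 128 * 5                      # per-token width carrying 5 query views
    packed_query = torch.randn(16, hidden)
    q, q_succ, q_inter, q_succ_crit, q_inter_crit = torch.split(
        packed_query, hidden // 5, dim=-1)   # as in the removed forward()

    # Key rescaling used when original_max_position_embeddings > 0:
    # scale = clip(0.1 * ln(seq_len / original_max_pos) + 1.0, min=1.0)
    original_max_pos = 32768              # example value
    seq_len = 131072                      # example value
    scale = max(0.1 * math.log(seq_len / original_max_pos) + 1.0, 1.0)
    print(round(scale, 4))                # ~1.1386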
- query = query[:num_prefill_tokens] - query_succ = query_succ[:num_prefill_tokens] - query_inter = query_inter[:num_prefill_tokens] - query_succ_critical = query_succ_critical[:num_prefill_tokens] - query_inter_critical = query_inter_critical[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache is None or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention, called during the profiling run. - out = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - assert prefill_meta.orig_seq_lens is not None - output[:num_prefill_tokens] = ( - self._dual_chunk_flash_attn_prefill( - q=query, - q_succ=query_succ, - q_inter=query_inter, - q_succ_critical=query_succ_critical, - q_inter_critical=query_inter_critical, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - orig_seq_lens=prefill_meta.orig_seq_lens, - scaling_factor=prefill_meta.scaling_factor, - softmax_scale=self.scale, - causal=True, - window_size=(-1, -1), - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - chunk_size=self.chunk_size, - local_size=self.local_size, - )) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output[num_prefill_tokens:] = ( - self._dual_chunk_flash_attn_decoding( - decode_query.unsqueeze(1), - decode_query_succ.unsqueeze(1), - decode_query_inter.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - chunk_size=self.chunk_size, - local_size=self.local_size, - original_max_position_embeddings=self. - original_max_position_embeddings, - decode_meta=decode_meta, - ).squeeze(1)) - # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) - - def _dual_chunk_flash_attn_prefill( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - orig_seq_lens: List[int], - scaling_factor: torch.Tensor, - softmax_scale: float, - causal: Optional[bool] = True, - window_size: Tuple[int, int] = (-1, -1), - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - chunk_size: int = 8192, - local_size: int = 1024, - ): - if alibi_slopes is not None: - raise ValueError( - "Dual Chunk Attention does not support alibi_slopes") - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - if window_size != (-1, -1): - raise ValueError( - "Dual Chunk Attention does not support window_size") - - cu_seqlens_q_cpu = cu_seqlens_q.cpu().tolist() - cu_seqlens_k_cpu = cu_seqlens_k.cpu().tolist() - all_outputs = [] - - for i in range(0, len(cu_seqlens_q_cpu) - 1): - qs = cu_seqlens_q_cpu[i] - qe = cu_seqlens_q_cpu[i:i + 2][-1] - ks = cu_seqlens_k_cpu[i] - ke = cu_seqlens_k_cpu[i:i + 2][-1] - - current_q = q[qs:qe] - current_q_succ = q_succ[qs:qe] - current_q_inter = q_inter[qs:qe] - current_q_succ_critical = q_succ_critical[qs:qe] - current_q_inter_critical = q_inter_critical[qs:qe] - - if block_table is None: - current_k = k[ks:ke] - current_v = v[ks:ke] - current_block_table = None - current_orig_seq_len = orig_seq_lens[i] - else: - current_block_table = block_table[i] - current_orig_seq_len = orig_seq_lens[i] - current_k = k - current_v = v - sparse_attn_enabled = (self.sparse_attention_enabled - and current_orig_seq_len - > self.sparse_attention_threshold) - - if current_q.shape[0] == 0: - continue - - if current_k.shape[0] == 0: - all_outputs.append( - torch.zeros( - (current_q.shape[0], current_q.shape[1], v.shape[2]), - device=q.device, - dtype=q.dtype, - )) - continue - - current_output = torch.empty_like(current_q) - group_size = int(current_q.size(-2) / current_k.size(-2)) - - if sparse_attn_enabled: - num_device_q_heads = current_q.size(-2) - heads_vertical_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - heads_slash_size = torch.empty(size=(num_device_q_heads, ), - dtype=torch.int32) - for head_id in range(current_q.size(-2)): - ( - ty, - vertical_size, - slash_size, - _, - ) = self.sparse_attention_config[head_id] - assert ty == "vertical_and_slash", "only support slash mode" - - if vertical_size == 30: - vertical_size += 100 - heads_vertical_size[head_id] = vertical_size - heads_slash_size[head_id] = slash_size - - current_output = self._dual_chunk_flash_attn_prefill_func( - current_q, # allheads - current_q_succ, - current_q_inter, - current_q_succ_critical, - current_q_inter_critical, - current_k, - current_v, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - heads_vertical_size=heads_vertical_size, - heads_slash_size=heads_slash_size, - group_size=group_size) - else: - for head_id in range(current_q.size(-2)): - # (seq_len, num_heads, head_size) - current_q_head = current_q[:, head_id, :].unsqueeze(1) - current_q_succ_head = \ - current_q_succ[:, head_id, :].unsqueeze(1) - current_q_inter_head = \ - current_q_inter[:, head_id, :].unsqueeze(1) - current_q_succ_head_critical = \ - current_q_succ_critical[:, head_id, :].unsqueeze(1) - current_q_inter_head_critical = \ - current_q_inter_critical[:, head_id, :].unsqueeze(1) - if block_table is not 
None: - current_k_head = current_k[..., head_id // - group_size, :].unsqueeze(2) - current_v_head = current_v[..., head_id // - group_size, :].unsqueeze(2) - - else: - current_k_head = current_k[:, head_id, :].unsqueeze(1) - current_v_head = current_v[:, head_id, :].unsqueeze(1) - - current_out = self._dual_chunk_flash_attn_prefill_func( - current_q_head, - current_q_succ_head, - current_q_inter_head, - current_q_succ_head_critical, - current_q_inter_head_critical, - current_k_head, - current_v_head, - current_block_table, - softmax_scale, - chunk_size, - local_size, - scaling_factor[i].item(), - ke - ks, - sparse_attn_enabled=sparse_attn_enabled, - ) - current_output[:, head_id:head_id + 1, :] = current_out - all_outputs.append(current_output) - return torch.cat(all_outputs, dim=0) - - def _dual_chunk_flash_attn_prefill_func( - self, - q, - q_succ, - q_inter, - q_succ_critical, - q_inter_critical, - k, - v, - block_table, - softmax_scale: float, - chunk_size: int, - local_size: int, - scaling_factor: float, - k_length: int, - sparse_attn_enabled: Optional[bool] = True, - heads_vertical_size=None, - heads_slash_size=None, - group_size=None, - ): - flash_results = [] - chunk_len = chunk_size - local_size - - if block_table is not None: - block_size = v.shape[1] - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - else: - block_size = 1 - - if self.original_max_position_embeddings > 0: - softmax_scale = softmax_scale * scaling_factor - - begin = k_length - q.shape[0] - while begin < k_length: - flash_per_chunk = [] - - prev_chunk_end_pos = (begin // chunk_len) * chunk_len - next_chunk_end_pos = prev_chunk_end_pos + chunk_len - end = min(next_chunk_end_pos, k_length) - qbegin = begin - (k_length - q.shape[0]) - qend = end - (k_length - q.shape[0]) - - qk_chunks = [] - q_states_intra = q[qbegin:qend] - # choose critical token - if block_table is not None: - block_tables_intra = _get_block(block_table, block_size, - prev_chunk_end_pos, end) - k_states_intra = k[block_tables_intra].view( - -1, *k.shape[-2:])[:(end - prev_chunk_end_pos)] - v_states_intra = v[block_tables_intra].view( - -1, *v.shape[-2:])[:(end - prev_chunk_end_pos)] - else: - block_tables_intra = None - k_states_intra = k[prev_chunk_end_pos:end] - v_states_intra = v[prev_chunk_end_pos:end] - - if sparse_attn_enabled: - last_q_size = min(qend - qbegin, self.sparse_attention_last_q) - _, num_device_k_heads, head_dim = k_states_intra.shape - k_states_intra = (k_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - v_states_intra = (v_states_intra.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, head_dim)) - qk_chunks.append( - (q_states_intra.transpose(0, 1)[:, -last_q_size:] * - softmax_scale) @ k_states_intra.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len >= 0: - q_states_succ = q_succ[qbegin:qend] - q_states_succ_critical = q_succ_critical[qbegin:qend] - if block_table is not None: - block_tables_succ = _get_block( - block_table, block_size, - prev_chunk_end_pos - chunk_len, prev_chunk_end_pos) - k_states_succ = k[block_tables_succ].view( - -1, *k.shape[-2:])[:chunk_len] - v_states_succ = v[block_tables_succ].view( - -1, *v.shape[-2:])[:chunk_len] - else: - k_states_succ = k[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - v_states_succ = v[prev_chunk_end_pos - - chunk_len:prev_chunk_end_pos] - - if sparse_attn_enabled: - k_states_succ = 
(k_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_succ = (v_states_succ.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_succ_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_succ.permute(1, 2, 0)) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - q_states_inter = q_inter[qbegin:qend] - q_states_inter_critical = q_inter_critical[qbegin:qend] - if block_table is not None: - block_tables_inter = _get_block( - block_table, block_size, 0, - prev_chunk_end_pos - chunk_len) - k_states_inter = k[block_tables_inter].view( - -1, *k.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - v_states_inter = v[block_tables_inter].view( - -1, *v.shape[-2:])[:(prev_chunk_end_pos - chunk_len)] - else: - k_states_inter = k[:prev_chunk_end_pos - chunk_len] - v_states_inter = v[:prev_chunk_end_pos - chunk_len] - - if sparse_attn_enabled: - k_states_inter = (k_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - v_states_inter = (v_states_inter.unsqueeze(2).repeat( - 1, 1, group_size, - 1).reshape(-1, num_device_k_heads * group_size, - head_dim)) - qk_chunks.append((q_states_inter_critical.transpose( - 0, 1)[:, -last_q_size:] * softmax_scale) - @ k_states_inter.permute(1, 2, 0)) - - if sparse_attn_enabled: - reversed_qk = qk_chunks[::-1] - qk = torch.cat(reversed_qk, dim=-1) - - qk[:, :, -last_q_size:] = torch.where( - self.last_q_mask[..., -last_q_size:, - -last_q_size:].to(qk.device), - qk[:, :, -last_q_size:], -torch.inf) - qk = F.softmax(qk, dim=-1, dtype=torch.float32) - - vertical = qk.sum(-2, keepdim=True) - vertical[..., :30] = torch.inf - - # Avoid sorting by using the min/max ints to fill the indexer - # buffers. 
- int32_max = torch.iinfo(torch.int32).max - int32_min = torch.iinfo(torch.int32).min - n_heads = qk.size()[0] - max_slash_topk = torch.max(heads_slash_size).item() - max_vertical_topk = torch.max(heads_vertical_size).item() - # store each head's slash topk, vertical topk - vertical = vertical.reshape((n_heads, -1)) - # prevent out of range when prompt size < max_vertical_topk - max_vertical_topk = min(vertical.shape[-1], max_vertical_topk) - vertical_topk_buffer = torch.topk(vertical, max_vertical_topk, - -1).indices - slash_topk_buffer = torch.empty(size=(n_heads, max_slash_topk), - dtype=torch.int64, - device=qk.device) - for head_i in range(n_heads): - # (nqheads=1, lastq, k_len) - head_score = qk[head_i:head_i + 1, :, :] - slash_scores = _sum_all_diagonal_matrix(head_score) - if head_score.size(1) != 1: - # drop right up corner - slash_scores = slash_scores[..., :-last_q_size + 1] - slash_scores[..., -100:] = torch.inf - - head_slash_size = heads_slash_size[head_i] - head_slash_size = min(head_slash_size, vertical.size(-1)) - slash_topk = torch.topk(slash_scores, head_slash_size, - -1).indices - #(nheads, max_topk) - slash_topk_buffer[head_i, :head_slash_size] = slash_topk - - # reset heads topk - heads_slash_size[head_i] = head_slash_size - heads_vertical_size[head_i] = min( - heads_vertical_size[head_i], max_vertical_topk) - - # store - vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - succ_vertical_buffer = torch.full((n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - succ_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - inter_vertical_buffer = torch.full( - (n_heads, max_vertical_topk), - int32_max, - dtype=torch.int64, - device=q.device) - inter_slash_buffer = torch.full((n_heads, max_slash_topk), - int32_min, - dtype=torch.int64, - device=q.device) - - vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - succ_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_vertical_size_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - inter_slash_sizes_buffer = torch.empty(size=(n_heads, ), - dtype=torch.int32, - device=q.device) - - for head_i in range(n_heads): - vertical_topk = vertical_topk_buffer[ - head_i, :heads_vertical_size[head_i]] - # intra - intra_vertical_indices = vertical_topk[ - vertical_topk >= - prev_chunk_end_pos] - prev_chunk_end_pos - if intra_vertical_indices.nelement() == 0: - intra_vertical_indices = torch.cat([ - intra_vertical_indices, - torch.arange(0, - k_states_intra.size(0), - max(1, - k_states_intra.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - slash_topk = slash_topk_buffer[ - head_i, :heads_slash_size[head_i]] - intra_slash_indices = ( - (qk.size(-1) - 1) - - slash_topk[slash_topk >= prev_chunk_end_pos]) - # fill buffer - v_count = intra_vertical_indices.nelement() - s_count = intra_slash_indices.nelement() - vertical_size_buffer[head_i] = v_count - slash_sizes_buffer[head_i] = s_count - vertical_buffer[head_i, :v_count].copy_( - intra_vertical_indices) - 
slash_buffer[head_i, :s_count].copy_(intra_slash_indices) - # succ - if prev_chunk_end_pos - chunk_len >= 0: - succ_vertical_indices = vertical_topk[ - (vertical_topk < prev_chunk_end_pos) - & (vertical_topk >= prev_chunk_end_pos - - chunk_len)] - (prev_chunk_end_pos - chunk_len) - # TODO: support no vertical - if succ_vertical_indices.nelement() == 0: - succ_vertical_indices = torch.cat([ - succ_vertical_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - succ_slash_indices = ( - (prev_chunk_end_pos + (qend - qbegin) - 1) - - slash_topk[((slash_topk >= - (prev_chunk_end_pos - chunk_len)) & - (slash_topk < (prev_chunk_end_pos + - (qend - qbegin))))]) - if succ_slash_indices.nelement() == 0: - succ_slash_indices = torch.cat([ - succ_slash_indices, - torch.arange( - 0, - k_states_succ.size(0), - max(1, - k_states_succ.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = succ_vertical_indices.nelement() - s_count = succ_slash_indices.nelement() - succ_vertical_size_buffer[head_i] = v_count - succ_slash_sizes_buffer[head_i] = s_count - succ_vertical_buffer[head_i, :v_count].copy_( - succ_vertical_indices) - succ_slash_buffer[head_i, :s_count].copy_( - succ_slash_indices) - - if prev_chunk_end_pos - 2 * chunk_len >= 0: - inter_vertical_indices = vertical_topk[ - vertical_topk < prev_chunk_end_pos - chunk_len] - - if inter_vertical_indices.nelement() == 0: - inter_vertical_indices = torch.cat([ - inter_vertical_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - inter_slash_indices = ( - (prev_chunk_end_pos - chunk_len + - (qend - qbegin) - 1) - - slash_topk[slash_topk < (prev_chunk_end_pos - - chunk_len + - (qend - qbegin))]) - if inter_slash_indices.nelement() == 0: - inter_slash_indices = torch.cat([ - inter_slash_indices, - torch.arange( - 0, - k_states_inter.size(0), - max(1, - k_states_inter.size(0) / 5), - dtype=torch.int32, - device=intra_vertical_indices.device) - ]) - # fill buffer - v_count = inter_vertical_indices.nelement() - s_count = inter_slash_indices.nelement() - inter_vertical_size_buffer[head_i] = v_count - inter_slash_sizes_buffer[head_i] = s_count - inter_vertical_buffer[head_i, :v_count].copy_( - inter_vertical_indices) - inter_slash_buffer[head_i, :s_count].copy_( - inter_slash_indices) - else: - intra_vertical_indices, intra_slash_indices = None, None - succ_vertical_indices, succ_slash_indices = None, None - inter_vertical_indices, inter_slash_indices = None, None - - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=vertical_buffer, - slash_indices=slash_buffer, - vertical_indices_count=vertical_size_buffer, - slash_indices_count=slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_intra, - k_states_intra, - v_states_intra, - softmax_scale=softmax_scale, - causal=True, - stage="intra", - vertical_indices=intra_vertical_indices, - slash_indices=intra_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len >= 0: - if sparse_attn_enabled: - flash_result = 
self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_buffer, - slash_indices=succ_slash_buffer, - vertical_indices_count=succ_vertical_size_buffer, - slash_indices_count=succ_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_succ, - k_states_succ, - v_states_succ, - softmax_scale=softmax_scale, - causal=False, - stage="succ", - vertical_indices=succ_vertical_indices, - slash_indices=succ_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - if prev_chunk_end_pos - chunk_len * 2 >= 0: - if sparse_attn_enabled: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_buffer, - slash_indices=inter_slash_buffer, - vertical_indices_count=inter_vertical_size_buffer, - slash_indices_count=inter_slash_sizes_buffer, - mergehead_softmax_scale=softmax_scale, - sparse_attn_enabled=sparse_attn_enabled) - else: - flash_result = self._do_flash_attn( - q_states_inter, - k_states_inter, - v_states_inter, - softmax_scale=softmax_scale, - causal=False, - stage="inter", - vertical_indices=inter_vertical_indices, - slash_indices=inter_slash_indices, - sparse_attn_enabled=sparse_attn_enabled) - flash_per_chunk.append(flash_result) - - flash_results.append(flash_per_chunk) - begin = end - - attn_output = self._merge_attn_outputs(flash_results) - del flash_results - return attn_output - - def _do_flash_attn( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - softmax_scale: float, - causal: bool = True, - max_seqlen_k: Optional[int] = None, - stage: str = "intra", - vertical_indices: Optional[torch.Tensor] = None, - slash_indices: Optional[torch.Tensor] = None, - vertical_indices_count: Optional[torch.Tensor] = None, - slash_indices_count: Optional[torch.Tensor] = None, - mergehead_softmax_scale: Optional[float] = None, - sparse_attn_enabled: Optional[bool] = False, - ): - if max_seqlen_k is None: - max_seqlen_k = key_states.shape[0] - - q_len = query_states.shape[0] - q_heads = query_states.shape[1] - h_dim = query_states.shape[-1] - - if sparse_attn_enabled: - assert slash_indices is not None - if stage == "intra": - assert causal - else: - assert not causal - - query_states = query_states.unsqueeze(0).transpose(1, 2) - key_states = key_states.unsqueeze(0).transpose(1, 2) - value_states = value_states.unsqueeze(0).transpose(1, 2) - - q = query_states - k = key_states - v = value_states - - if (vertical_indices_count is not None and \ - slash_indices_count is not None): - assert mergehead_softmax_scale is not None - - res, s_lse = _vertical_slash_sparse_attention( - q, - k, - v, - vertical_indices, - slash_indices, - mergehead_softmax_scale, - causal=causal, - stage=stage, - vertical_indices_count=vertical_indices_count, - slash_indices_count=slash_indices_count) - res = res.view(q_heads, q_len, - h_dim).transpose(0, 1) # (qlen,nhead,h_dim) - s_lse = s_lse.view( - q_heads, q_len, - 1).squeeze(-1).unsqueeze(0).float() # (1, nhead,qlen) - else: - res, s_lse = _vertical_slash_sparse_attention(q, - k, - v, - vertical_indices, - slash_indices, - softmax_scale, - causal=causal, - stage=stage) - res = res.view(q_len, q_heads, h_dim) - s_lse = s_lse.view(q_len, q_heads, 
1).transpose(0, 2).float() - return res, s_lse - - output, softmax_lse = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - softmax_scale=softmax_scale, - cu_seqlens_q=torch.tensor([0, query_states.shape[0]], - dtype=torch.int32, - device=query_states.device), - max_seqlen_q=query_states.shape[0], - cu_seqlens_k=torch.tensor([0, max_seqlen_k], - dtype=torch.int32, - device=query_states.device), - max_seqlen_k=max_seqlen_k, - causal=causal, - return_softmax_lse=True, - ) - softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, - 2).float() - return output, softmax_lse - - def _merge_attn_outputs( - self, - flash_results: List[List[Tuple[torch.Tensor, torch.Tensor]]], - return_lse: Optional[bool] = False, - ) -> torch.Tensor: - attn_outputs_all = [] - logits_all = [] - - for flash_per_chunk in flash_results: - if len(flash_per_chunk) == 1: - attn_outputs_all.append(flash_per_chunk[0][0]) - if return_lse: - logits_all.append(flash_per_chunk[0][1]) - continue - - attn_outputs = torch.stack([ - flash_attn_output[0] for flash_attn_output in flash_per_chunk - ]) - logits = torch.stack([ - flash_attn_output[1] for flash_attn_output in flash_per_chunk - ]) - logits = logits.to(torch.float32) - - if return_lse: - max_val = torch.max(logits, dim=0).values - diff = torch.abs(logits[0] - logits[1]) - log_sum_exp = max_val + torch.log1p(torch.exp(-diff)) - logits_all.append(log_sum_exp) - - max_logits = torch.max(logits, dim=0).values - stable_logits = logits - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - attn_outputs *= lse_s.unsqueeze(-1).transpose(2, 3).squeeze(1) - attn_outputs_all.append(attn_outputs.sum(dim=0)) - - if return_lse: - return (torch.cat(attn_outputs_all, - dim=0), torch.cat(logits_all, dim=-1)) - else: - return torch.cat(attn_outputs_all, dim=0) - - def _dual_chunk_flash_attn_decoding( - self, - query: torch.Tensor, - query_succ: torch.Tensor, - query_inter: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - causal: bool, - alibi_slopes: Optional[torch.Tensor], - chunk_size: int, - local_size: int, - original_max_position_embeddings: int, - decode_meta: DualChunkFlashAttentionMetadata, - ): - if not causal: - raise ValueError( - "Dual Chunk Attention does not support causal=False") - - block_size = value_cache.shape[1] - chunk_len = chunk_size - local_size - if chunk_len % block_size != 0: - raise ValueError("chunk_len must be divisible by block_size.") - if original_max_position_embeddings > 0: - assert decode_meta.scaling_factor is not None - scaling_factor = decode_meta.scaling_factor - query = (query * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype - ) # possible for numerical issue, need to fused in the kernel - query_succ = (query_succ * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - query_inter = (query_inter * scaling_factor.view(-1, 1, 1, 1)).to( - query.dtype) - outputs_list = [] - softmax_lses_list = [] - - # intra-attention - intra_output, intra_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query, - key_cache, - value_cache, - decode_meta.block_tables_intra, - decode_meta.seq_lens_intra, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(intra_output) - softmax_lses_list.append(intra_softmax_lse) - - # succ-attention - if decode_meta.max_seq_len_succ: - succ_output, succ_softmax_lse = ( - 
self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_succ, - key_cache, - value_cache, - decode_meta.block_tables_succ, - decode_meta.seq_lens_succ, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(succ_output) - softmax_lses_list.append(succ_softmax_lse) - - # inter-attention - if decode_meta.max_seq_len_inter: - inter_output, inter_softmax_lse = ( - self._dual_chunk_flash_attn_decoding_with_exp_sums( - query_inter, - key_cache, - value_cache, - block_table[:, :decode_meta.max_seq_len_inter], - decode_meta.seq_lens_inter, - softmax_scale, - alibi_slopes, - causal=False, - )) - outputs_list.append(inter_output) - softmax_lses_list.append(inter_softmax_lse) - outputs = torch.stack(outputs_list, dim=0) - del outputs_list - softmax_lses = torch.stack(softmax_lses_list, dim=0).to(torch.float32) - del softmax_lses_list - max_logits = torch.max(softmax_lses, dim=0).values - stable_logits = softmax_lses - max_logits.unsqueeze(0) - lse_s = torch.exp(stable_logits).detach() - lse_sum = torch.sum(lse_s, dim=0) - lse_s /= lse_sum - outputs *= lse_s.unsqueeze(-1).transpose(2, 3) - return outputs.sum(0) - - def _dual_chunk_flash_attn_decoding_with_exp_sums( - self, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_table: torch.Tensor, - cache_seqlens: torch.Tensor, - softmax_scale: float, - alibi_slopes: Optional[torch.Tensor], - causal: bool, - ): - out, softmax_lse = flash_attn_with_kvcache( - q=query, - k_cache=key_cache, - v_cache=value_cache, - block_table=block_table, - cache_seqlens=cache_seqlens, - softmax_scale=softmax_scale, - alibi_slopes=alibi_slopes, - causal=causal, - return_softmax_lse=True, - ) - mask = (cache_seqlens == 0) - out[mask] = 0 - softmax_lse[mask] = -float("inf") - return out, softmax_lse - - -def _vertical_slash_sparse_attention( - query: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD] - key: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - value: torch.Tensor, # [BATCH, N_HEADS, N_KV_CTX, D_HEAD] - v_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] - s_idx: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] - softmax_scale: float, - causal: bool = True, - stage: str = "intra", - block_size_M: int = 64, - block_size_N: int = 64, - vertical_indices_count: torch.Tensor = None, # [N_HEADS,] - slash_indices_count: torch.Tensor = None, -): - if stage == "intra": - assert causal - else: - assert not causal - - batch_size, num_heads, context_size, head_dim = query.shape - _, _, kv_seq_len, _ = key.shape - - if head_dim not in [16, 32, 64, 128, 256, 512]: - target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim - query = F.pad(query, [0, target_dim, 0, 0, 0, 0, 0, 0]) - key = F.pad(key, [0, target_dim, 0, 0, 0, 0, 0, 0]) - value = F.pad(value, [0, target_dim, 0, 0, 0, 0, 0, 0]) - - v_idx = v_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=False)[0] - s_idx = s_idx.to(torch.int32).reshape( - (batch_size, num_heads, -1)).sort(dim=-1, descending=True)[0] - q_seqlens = torch.tensor([context_size], - dtype=torch.int32, - device=query.device) - kv_seqlens = torch.tensor([kv_seq_len], - dtype=torch.int32, - device=query.device) - - if vertical_indices_count is not None and slash_indices_count is not None: - ( - block_count, - block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes_mergehead( - q_seqlens, kv_seqlens, v_idx, s_idx, vertical_indices_count, - slash_indices_count, context_size, block_size_M, block_size_N, - causal) - else: - ( - block_count, - 
block_offset, - column_count, - column_index, - ) = ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, v_idx, - s_idx, context_size, - block_size_M, block_size_N, - causal) - - q = query.transpose(1, 2).contiguous() - k = key.transpose(1, 2).contiguous() - v = value.transpose(1, 2).contiguous() - out, lse = sparse_attn_func( - q, - k, - v, - block_count, - block_offset, - column_count, - column_index, - causal=causal, - softmax_scale=softmax_scale, - return_softmax_lse=True, - ) - out = out.transpose(1, 2).contiguous() - softmax_lse = lse.reshape(*lse.shape, 1) - return (out[..., :context_size, :head_dim], - softmax_lse[..., :context_size, :]) - - -def _sum_all_diagonal_matrix(mat: torch.tensor): - h, n, m = mat.shape - # Zero matrix used for padding - zero_mat = torch.zeros((h, n, n), device=mat.device) - # pads the matrix on left and right - mat_padded = torch.cat((zero_mat, mat, zero_mat), -1) - # Change the strides - mat_strided = mat_padded.as_strided((1, n, n + m), - (n * (2 * n + m), 2 * n + m + 1, 1)) - # Sums the resulting matrix's columns - sum_diags = torch.sum(mat_strided, 1) - return sum_diags[:, 1:] # drop left bottom corner - - -def _get_block(block_table: torch.Tensor, block_size: int, begin: int, - end: int): - begin_block = begin // block_size - end_block = (end - 1) // block_size + 1 - return block_table[begin_block:end_block] diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py deleted file mode 100755 index edb3afb4aa07..000000000000 --- a/vllm/attention/backends/flash_attn.py +++ /dev/null @@ -1,929 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict -from dataclasses import dataclass -from itertools import accumulate -from typing import Dict, List, Optional, Tuple, Type - -import torch - -from vllm import _custom_ops as ops -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -# yapf: enable -from vllm.attention.backends.utils import ( - PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, - compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, - get_seq_len_block_table_args, is_all_cross_attn_metadata_set, - is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, - get_flash_attn_version) -from vllm.logger import init_logger -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import async_tensor_h2d, make_tensor_with_pad -from vllm.vllm_flash_attn import (flash_attn_varlen_func, - flash_attn_with_kvcache) - -logger = init_logger(__name__) - - -class FlashAttentionBackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] - - @staticmethod - def get_name() -> str: - return "FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["FlashAttentionImpl"]: - return FlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return FlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: - return FlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> 
Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class FlashAttentionMetadata(AttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - - use_cuda_graph: bool - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
- seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = FlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = FlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - encoder_seq_start_loc=self.encoder_seq_start_loc, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_decode_metadata - - -class FlashAttentionMetadataBuilder( - AttentionMetadataBuilder[FlashAttentionMetadata]): - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. 
- """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - - if is_prompt: - mm_maps = inter_data.multi_modal_placeholder_maps - if mm_maps: - for modality, placeholders in mm_maps.items(): - self.multimodal_placeholder_maps[modality].extend( - placeholders) - - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
- """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, self.runner.pin_memory) - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - self.multimodal_placeholder_maps.items() - } - - return FlashAttentionMetadata( - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - seq_lens=seq_lens, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) - - -class FlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. 
- - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "FLASH_ATTN backend.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window - 1, - 0) if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - self.vllm_flash_attn_version = get_flash_attn_version( - requires_alibi=self.alibi_slopes is not None) - if is_quantized_kv_cache(self.kv_cache_dtype) and ( - not self.kv_cache_dtype.startswith("fp8") - or not flash_attn_supports_fp8()): - raise NotImplementedError( - f"FlashAttention does not support {self.kv_cache_dtype} " - "kv-cache on this device " - f"(FA supports fp8 = {flash_attn_supports_fp8()}).") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() - if head_size not in support_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by FlashAttention. " - f"Supported head sizes are: {support_head_sizes}.") - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: FlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - output: shape = [num_tokens, num_heads, head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size, num_kv_heads, head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - NOTE: It in-place updates the output tensor. - NOTE: FP8 quantization, flash-attn expect the size of - {q,k,v}_descale to be (num_sequences, num_kv_heads). - We use torch's .expand() to avoid duplicating values - """ - assert output is not None, "Output tensor must be provided." 
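# A small sketch of the `.expand()` note in the docstring above, assuming the
# per-tensor q/k/v scales are stored as single-element tensors (toy values):
# expanding to (num_sequences, num_kv_heads) creates a broadcast view, so the
# descale arguments do not duplicate any memory.
import torch

num_sequences, num_kv_heads = 3, 8
q_scale = torch.tensor([0.5])                      # hypothetical per-tensor scale
q_descale = q_scale.expand(num_sequences, num_kv_heads)
assert q_descale.shape == (num_sequences, num_kv_heads)
assert q_descale.data_ptr() == q_scale.data_ptr()  # a view, not a copy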
- - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for FlashAttentionImpl") - - # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache. - if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16: - assert ( - layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), ( - "key/v_scale is only supported in FlashAttention 3 with " - "base dtype bfloat16") - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - kv_cache_dtype: str = self.kv_cache_dtype - softmax_scale: float = self.scale - window_size = self.sliding_window - alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes - logits_soft_cap: Optional[float] = self.logits_soft_cap - fp8_attention = kv_cache_dtype.startswith("fp8") - - if fp8_attention and not flash_attn_supports_fp8(): - raise NotImplementedError( - "FlashAttention does not support FP8 kv-cache on this device.") - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the - # KV cache is already populated with the cross-attention - # tensor. Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if fp8_attention: - kv_cache = kv_cache.view(torch.float8_e4m3fn) - key_cache = key_cache.view(torch.float8_e4m3fn) - value_cache = value_cache.view(torch.float8_e4m3fn) - - if fp8_attention: - num_tokens, num_heads, head_size = query.shape - query, _ = ops.scaled_fp8_quant( - query.reshape( - (num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale) - query = query.reshape((num_tokens, num_heads, head_size)) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - decode_output = output[num_prefill_query_tokens:] - # QKV for prefill. 
- query = query[:num_prefill_query_tokens] - prefill_output = output[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - if fp8_attention: - num_kv_tokens, num_kv_heads, head_size = key.shape - - key, _ = ops.scaled_fp8_quant( - key.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._k_scale) - key = key.reshape((num_kv_tokens, num_kv_heads, head_size)) - - value, _ = ops.scaled_fp8_quant( - value.reshape((num_kv_tokens, - num_kv_heads * head_size)).contiguous(), - layer._v_scale) - value = value.reshape( - (num_kv_tokens, num_kv_heads, head_size)) - - descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) - flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - assert prefill_meta.query_start_loc is not None - max_seq_len = max(prefill_meta.seq_lens) - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens_tensor, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - out=prefill_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. 
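# A minimal sketch of the cu_seqlens convention consumed by the varlen call
# below: per-request query lengths (which may differ under speculative
# decoding) are flattened into one token dimension and described by their
# cumulative offsets. The lengths here are made-up toy values.
from itertools import accumulate

import torch

query_lens = [1, 3, 2]                  # three decode requests, unequal lengths
cu_seqlens_q = torch.tensor(list(accumulate(query_lens, initial=0)),
                            dtype=torch.int32)   # tensor([0, 1, 4, 6])
max_seqlen_q = max(query_lens)          # 3
# Tokens of request i occupy rows cu_seqlens_q[i]:cu_seqlens_q[i + 1] of the
# flattened query tensor passed as `q`.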
- - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1" - ) - assert decode_meta.query_start_loc is not None - descale_shape = (decode_meta.query_start_loc.shape[0] - 1, - key.shape[1]) - flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - seqused_k=decode_meta.seq_lens_tensor, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - out=decode_output, - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) - flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - out=decode_output.unsqueeze(1), - fa_version=self.vllm_flash_attn_version, - q_descale=layer._q_scale.expand(descale_shape), - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - return output - - -def _get_query_key_seq_metadata( - attn_metadata: FlashAttentionMetadata, - is_prompt: bool, - attn_type: str, -) -> tuple: - """ - Returns sequence metadata for key and query based on the specified - attention type and whether input is a prompt. - - This function computes the starting locations and maximum sequence lengths - for key and query sequences for different attention types. - - Args: - attn_metadata: The attention metadata object - is_prompt (bool): A flag indicating if the input is a prompt - attn_type (AttentionType): The type of attention being used. - - Returns: - tuple: A tuple containing four integers: - - Starting location for the query sequence. - - Maximum sequence length for the query sequence. - - Starting location for the key sequence. - - Maximum sequence length for the key sequence. - - Raises: - AttributeError: If an invalid attention type is provided. - """ - if attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.seq_start_loc, max_seq_len) - - elif attn_type == AttentionType.ENCODER_DECODER: - # This is cross attention between the where the key - # is the precomputed encoder attention and query - # is the input sequence. - # Choose query max length based on whether it is prompt - # or not. 
- if is_prompt: - max_seq_len = attn_metadata.max_prefill_seq_len - else: - max_seq_len = attn_metadata.max_decode_seq_len - return (attn_metadata.seq_start_loc, max_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER: - # For encoder attention both the query and the key are same i.e. the - # encoder sequence. - return (attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_start_loc, - attn_metadata.max_encoder_seq_len) - elif attn_type == AttentionType.ENCODER_ONLY: - assert is_prompt, "Should not have decode for encoder only model." - return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, - attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _get_causal_option(attn_type: str) -> bool: - """ - Determine whether the given attention type is suitable for causal - attention mechanisms. - - Args: - attn_type (AttentionType): The type of attention being evaluated - - Returns: - bool: Returns `True` if the attention type is suitable for causal - attention (i.e., not encoder, encoder-only, or encoder-decoder), - otherwise returns `False`. - """ - return not (attn_type == AttentionType.ENCODER - or attn_type == AttentionType.ENCODER_ONLY - or attn_type == AttentionType.ENCODER_DECODER) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py deleted file mode 100644 index aeaa0ab631cf..000000000000 --- a/vllm/attention/backends/flashmla.py +++ /dev/null @@ -1,227 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata, - is_flashmla_supported) - - -class FlashMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "FLASHMLA" - - @staticmethod - def get_impl_cls() -> Type["FlashMLAImpl"]: - return FlashMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["FlashMLAMetadata"]: - return FlashMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: - return FlashMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["FlashMLAState"]: - return FlashMLAState - - -@dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - # TODO: cache assignment? 
- if decode_metadata is not None: - decode_metadata.decode_tile_scheduler_metadata=\ - self.decode_tile_scheduler_metadata - decode_metadata.decode_num_splits=\ - self.decode_num_splits - return decode_metadata - - -class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - m = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - - if m.num_decode_tokens > 0: - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ - get_mla_metadata( - m.seq_lens_tensor[m.num_prefills:], - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m - - -class FlashMLAState(MLACommonState[FlashMLAMetadata]): - - def __init__(self, *args, **kwds): - super().__init__(*args, **kwds) - - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) - - @contextmanager - def graph_capture(self, max_batch_size: int): - # Run a dummy `get_mla_metadata` so we can get the right shapes - self._graph_decoder_tile_scheduler_metadata, \ - self._graph_decode_num_splits = get_mla_metadata( - torch.ones( - max_batch_size, dtype=torch.int32, device=self.runner.device), - self.num_q_heads, - 1, # MQA for the decode path - ) - - with super().graph_capture(max_batch_size): - yield - - del self._graph_decoder_tile_scheduler_metadata - del self._graph_decode_num_splits - - def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - assert metadata.num_decode_tokens > 0 - - decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( - self._graph_seq_lens[:batch_size], - self.num_q_heads, - 1, # MQA for the decode path - ) - - self._graph_decoder_tile_scheduler_metadata.copy_( - decoder_tile_scheduler_metadata) - self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) - - metadata.decode_tile_scheduler_metadata=\ - self._graph_decoder_tile_scheduler_metadata - metadata.decode_num_splits=\ - self._graph_decode_num_splits[:batch_size + 1] - - return metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers["decode_tile_scheduler_metadata"] = \ - attn_metadata.decode_metadata.decode_tile_scheduler_metadata - input_buffers["decode_num_splits"] = \ - attn_metadata.decode_metadata.decode_num_splits - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - input_buffers["decode_tile_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_tile_scheduler_metadata) - input_buffers["decode_num_splits"].copy_( - attn_metadata.decode_metadata.decode_num_splits) - - -class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - 
attn_type: str, - kv_sharing_target_layer_name: Optional[str] = None, - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - is_supported, reason = is_flashmla_supported() - assert is_supported, reason - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FlashMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: FlashMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1)\ - .unsqueeze(1) # Add seqlen dim of 1 (decode) - - o, _ = flash_mla_with_kvcache( - q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, - num_splits=decode_meta.decode_num_splits, - softmax_scale=self.scale, - causal=True, - ) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/mla/__init__.py b/vllm/attention/backends/mla/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py deleted file mode 100644 index 826b63e1ccda..000000000000 --- a/vllm/attention/backends/mla/common.py +++ /dev/null @@ -1,1305 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -# MLA Common Components - -This file implements common components for MLA implementations. - -First we define: - -Sq as Q sequence length -Skv as KV sequence length - -MLA has two possible ways of computing, a data-movement friendly approach and a -compute friendly approach, we generally want to use the compute friendly -approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) -and the data-movement friendly approach for "decode" (i.e. the ratio -Sq / Skv is "large"). - -NOTE what we deem small and large is currently determined by if its labelled -prefill or decode by the scheduler, but this is something we should probably -tune. - -Main reference: DeepseekV2 paper, and FlashInfer Implementation -(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - -Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. -* For decode (i.e. the memory friendly approach) the attention "simulates" a -multi-head attention, while the compute is similar to multi-query attention. 
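A quick numerical check of why the two paths agree, as a minimal
self-contained sketch with toy shapes (the tensor names follow the
definitions below and are illustrative only):

import torch

N, P, Lkv, Sq, Skv = 4, 8, 16, 2, 6     # heads, nope dim, latent dim, Q len, KV len
W_UK = torch.randn(Lkv, N, P)           # latent -> per-head key up-projection
q_nope = torch.randn(Sq, N, P)          # uncompressed (no-rope) queries
kv_c = torch.randn(Skv, Lkv)            # latent KV cache entries

# MHA-style path: materialize per-head keys, then compute Q.K^T per head.
k_nope = torch.einsum("sl,lnp->snp", kv_c, W_UK)
logits_mha = torch.einsum("qnp,snp->nqs", q_nope, k_nope)

# MQA-style path: absorb W_UK into the queries and attend over the shared latent.
ql_nope = torch.einsum("qnp,lnp->qnl", q_nope, W_UK)
logits_mqa = torch.einsum("qnl,sl->nqs", ql_nope, kv_c)

assert torch.allclose(logits_mha, logits_mqa, atol=1e-5)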
- -Below is example of both paths assuming batchsize = 1 - -## More Extent Definitions: - -C Context length, `Skv - Sq` -H hidden size -N number of attention heads -Lq latent dimension for Q 1536 in DSV3 -Lkv latent dimension for K/V 512 in DSV3 -P nope dimension, no rope. 128 in DSV3 -R rope dimension, goes through rope. 64 in DSV3 -V V head dim. 128 in DSV3 - -## Vector/Matrix Definitions - -h_t hidden states (input to attention) shape [Sq, H] -q_c latent/compressed Q shape [Sq, Lq] -q_nope uncompressed Q (no-rope) shape [Sq, N, P] -q_pe uncompressed Q (rope) shape [Sq, N, R] -kv_c latent/compressed KV shape [Skv, Lkv] -k_pe decoupled k position embeddings shape [Skv, R] -new_kv_c new kv_c from current iter shape [Sq, Lkv] -new_k_pe new k_pe from current iter shape [Sq, R] -cache_kv_c cached k_c from previous iters shape [C, Lkv] -cache_k_pe cached k_pe from previous iters shape [C, R] -W_DQ project h_t to q_c shape [H, Lq] -W_UQ project q_c to q_nope shape [Lq, N * P] -W_QR project q_c to q_pe shape [Lq, N * R] -W_DKV project h_t to kv_c shape [H, Lkv] -W_UK project kv_c to k_nope shape [Lkv, N, P] -W_KR project h_t to k_pe shape [H, R] -W_UV project kv_c to v shape [Lkv, N, V] -W_O project v to h_t shape [N * V, H] - - -## Compute Friendly Approach (i.e. "_forward_prefill"): - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) -k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) -v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) - -// MHA with QK headdim = P + R -// V headdim = V -// spda_o shape [Sq, N, V] -spda_o = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - v -) -return spda_o @ W_O - -NOTE: in the actual code, - `kv_b_proj` is [W_UK; W_UV] concatenated per head - `q_b_proj` is [W_UQ; W_QR] concatenated per head - `out_proj` is W_O - - -## Data-Movement Friendly Approach (i.e. "_forward_decode"): - -Runtime -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(-1, N, P) -ql_nope = einsum("snh,lnh->snl", q, W_UK) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) -k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) - -// MQA with QK headdim = Lkv + R -// V headdim = Lkv -// spda_o shape [Sq, N, Lkv] -// NOTE: this is less compute-friendly since Lkv > P -// but is more data-movement friendly since its MQA vs MHA -spda_o = scaled_dot_product_attention( - torch.cat([ql_nope, q_pe], dim=-1), - torch.cat([kv_c, k_pe], dim=-1), - kv_c -) - -o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) -return o.view(-1, N * V) @ self.num_heads @ W_O - - -## Chunked Prefill - -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to -the data-movement friendly approach if the chunk (i.e. `Sq`) is small. - -However, the compute-friendly approach can potentially run out of memory if Skv -is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` - -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a -fixed workspace size. 
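The `merge_attn_states` step in the pseudocode below combines the partial
result for the new tokens with the partial result of each context chunk
using their softmax log-sum-exps. A minimal dense-tensor sketch of that
merge (a toy reference, not the fused kernel; shapes follow the pseudocode,
with outputs [Sq, N, V] and LSEs [N, Sq]):

import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o1/o2: [Sq, N, V] attention outputs over two disjoint key sets
    # lse1/lse2: [N, Sq] log-sum-exp of the corresponding softmax denominators
    l1 = lse1.transpose(0, 1).unsqueeze(-1)   # [Sq, N, 1]
    l2 = lse2.transpose(0, 1).unsqueeze(-1)
    m = torch.maximum(l1, l2)                 # subtract the max for stability
    w1, w2 = torch.exp(l1 - m), torch.exp(l2 - m)
    merged = (o1 * w1 + o2 * w2) / (w1 + w2)
    merged_lse = (m + torch.log(w1 + w2)).squeeze(-1).transpose(0, 1)
    return merged, merged_lse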
- -The chunked prefill approach is as follows: - -MCC Max chunk of context to process per iter, computed dynamically, - used to bound the memory usage - -q_c = h_t @ W_DQ -q_nope = (q_c @ W_UQ).view(Sq, N, P) -q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) -new_kv_c = h_t @ W_DKV -new_k_pe = RoPE(h_t @ W_KR) -new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) -new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) - -// MHA between queries and new KV -// with QK headdim = P + R -// V headdim = V -// curr_o shape [Sq, N, V] -// curr_lse shape [N, Sq], this is just order FA returns -curr_o, curr_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), - new_v, - casual=True, - return_softmax_lse=True -) - -// Compute attention with the already existing context -for chunk_idx in range(cdiv(C, MCC)): - chunk_start = chunk_idx * MCC - chunk_end = min(chunk_start + MCC, C) - Sc = chunk_end - chunk_start - cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end] - cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] - cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) - cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) - - chunk_o, chunk_lse = scaled_dot_product_attention( - torch.cat([q_nope, q_pe], dim=-1), - torch.cat([cache_k_nope_chunk, - cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], - dim=-1), - cache_v_chunk, - casual=False, - return_softmax_lse=True - ) - - curr_o, curr_lse = merge_attn_states( - suffix_output=curr_o, - suffix_lse=curr_lse, - prefix_output=chunk_o, - prefix_lse=chunk_lse, - ) - -return curr_o @ W_O -""" - -import functools -from abc import abstractmethod -from collections import defaultdict -from contextlib import contextmanager -from dataclasses import dataclass -from itertools import accumulate -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar - -import torch - -from vllm import _custom_ops as ops -from vllm import envs -from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, MLAAttentionImpl) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.attention.utils.fa_utils import get_flash_attn_version -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearBase, - UnquantizedLinearMethod) -from vllm.multimodal import MultiModalPlaceholderMap -from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON -from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down - -if HAS_TRITON: - from vllm.attention.ops.triton_flash_attention import triton_attention -else: - triton_attention = None - -try: - from vllm.vllm_flash_attn import flash_attn_varlen_func - is_vllm_fa = True -except ImportError: - is_vllm_fa = False - try: - # For rocm use upstream flash attention - from flash_attn import flash_attn_varlen_func - except ImportError: - flash_attn_varlen_func = None - -is_hip = current_platform.is_rocm() - - -class MLACommonBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return MLACommonMetadata - - @staticmethod - def get_builder_cls() -> Type["MLACommonMetadataBuilder"]: - return MLACommonMetadataBuilder 
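    # NOTE: the MLA KV cache stores a single 576-wide entry per token: the
    # compressed latent (Lkv = 512 in DSV3) concatenated with the decoupled
    # rope key (R = 64). This is why get_supported_head_sizes() returns [576]
    # and why get_kv_cache_shape() has no per-head dimension.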
- - @staticmethod - def get_state_cls() -> Type["MLACommonState"]: - return MLACommonState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -T = TypeVar("T", bound="MLACommonMetadata") - - -class MLACommonState(AttentionState, Generic[T]): - - def __init__(self, runner): - self.runner = runner - self._is_graph_capturing = False - - scheduler_config = runner.scheduler_config - self.model_config = runner.model_config - cache_config = runner.cache_config - - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - self.context_chunk_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max( - 8 * self.model_config.max_model_len, 4 * - scheduler_config.max_num_seqs * cache_config.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.context_chunk_workspace_size >= \ - scheduler_config.max_num_seqs * cache_config.block_size - - @contextmanager - def graph_capture(self, max_batch_size: int): - self._is_graph_capturing = True - - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) - self._graph_seq_lens = torch.ones(max_batch_size, - dtype=torch.int32, - device=self.runner.device) - self._graph_block_tables = torch.from_numpy( - self.runner.graph_block_tables).to(device=self.runner.device) - - self._positions = torch.zeros((max_batch_size, ), - dtype=torch.long, - device=self.runner.device) - - yield - - self._is_graph_capturing = False - del self._graph_slot_mapping - del self._graph_seq_lens - del self._graph_block_tables - del self._positions - - def graph_clone(self, batch_size: int): - assert self._is_graph_capturing - return self.__class__(self.runner) - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> T: - assert self._is_graph_capturing - - attn_metadata = self.runner.attn_backend.make_metadata( - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - use_cuda_graph=True, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=self._graph_slot_mapping[:batch_size], - seq_lens=None, - seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=1, - max_decode_query_len=1, - max_prefill_seq_len=0, - max_decode_seq_len=self.runner.max_seq_len_to_capture, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - 
block_tables=self._graph_block_tables[:batch_size], - head_dim=self.runner.model_config.get_head_size()) - - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return attn_metadata - - def get_graph_input_buffers(self, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers = { - "slot_mapping": attn_metadata.slot_mapping, - "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, - "block_tables": attn_metadata.decode_metadata.block_tables, - } - if is_encoder_decoder_model: - raise NotImplementedError( - "MLACommonState does not support encoder/decoder yet") - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): - input_buffers["seq_lens_tensor"].copy_( - attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) - input_buffers["block_tables"].copy_( - attn_metadata.decode_metadata.block_tables, non_blocking=True) - if is_encoder_decoder_model: - raise NotImplementedError( - "TritonMLAState does not support encoder/decoder yet") - - def begin_forward(self, model_input): - if self.chunked_prefill_enabled or self.enable_prefix_caching: - if not hasattr(self, "context_chunk_workspace"): - # not self.runner.device does not return the correct device - # for this process, (init_device sets the correct device but - # only on the Worker). The only way Ive figured out to get the - # correct device is to allocate the workspace on the first call - # to begin_forward and use the device of the input tokens - assert model_input.input_tokens is not None - self.context_chunk_workspace = torch.empty( - (self.context_chunk_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=model_input.input_tokens.device, - ) - - model_input.attn_metadata.context_chunk_workspace = \ - self.context_chunk_workspace - - -@dataclass -class MLACommonMetadata(AttentionMetadata): - """Metadata for MLACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). 
- context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Maximum query length in the batch. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - _cached_prefill_metadata: Optional[Any] = None - _cached_decode_metadata: Optional[Any] = None - - num_prefill_tokens: int - - # The dimension of the attention heads - head_dim: Optional[int] = None - - # Used when chunked prefill is enabled to simulate worst case workspace - # allocations, hopefully to avoid going OOM - is_profile_run: bool = False - - # New for MLA (compared to FlashAttention) - # For chunked prefill - context_chunk_cu_seq_lens: Optional[torch.Tensor] = None - context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[List[int]] = None - context_chunk_max_seq_lens: Optional[List[int]] = None - # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted - context_chunk_workspace: Optional[torch.Tensor] = None - - def __post_init__(self): - supported_head_sizes = MLACommonBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: - raise ValueError( - f"Only {supported_head_sizes} are supported for head_dim,", - f" received {self.head_dim}.") - - @property - def prefill_metadata(self): - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - self._cached_prefill_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=False, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - 
multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_query_len=0, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens, - context_chunk_starts=self.context_chunk_starts, - context_chunk_seq_tot=self.context_chunk_seq_tot, - context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.seq_lens_tensor is not None - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - self._cached_decode_metadata = self.__class__( - # Required by ModelRunner - use_cuda_graph=self.use_cuda_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=None, - seq_lens_tensor=seq_lens_tensor, - max_decode_query_len=self.max_decode_query_len, - max_query_len=self.max_query_len, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. 
- query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, - block_tables=block_tables, - head_dim=self.head_dim, - is_profile_run=self.is_profile_run) - return self._cached_decode_metadata - - -class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - BLOCK_TABLE_EXTENDER: list[list[int]] = [] - - def __init__(self, input_builder): - self.input_builder = input_builder - self.runner = input_builder.runner - self.sliding_window = input_builder.sliding_window - self.block_size = input_builder.block_size - self.chunked_prefill_enabled = \ - self.runner.scheduler_config.chunked_prefill_enabled - self.enable_prefix_caching = \ - self.runner.cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - attn_state = self.input_builder.runner.attn_state - self.context_chunk_workspace_size = \ - attn_state.context_chunk_workspace_size - self.page_size = self.runner.block_size - - def prepare(self): - self.slot_mapping: List[int] = [] - self.prefill_seq_lens: List[int] = [] - self.context_lens: List[int] = [] - self.block_tables: List[List[int]] = [] - self.curr_seq_lens: List[int] = [] - self.multimodal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.num_decode_tokens = 0 - self.has_prefix_cache_hit = False - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. - """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. 
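            # Each token position i in [start_idx, seq_len) is mapped to
            # block_table[i // block_size] * block_size + i % block_size;
            # earlier positions (and profile runs without block tables) are
            # filled with PAD_SLOT_ID.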
- is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - - def _get_graph_runner_block_tables( - self, num_seqs: int, - block_tables: List[List[int]]) -> torch.Tensor: - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - max_batch_size, max_blocks = self.runner.graph_block_tables.shape - assert max_batch_size >= num_seqs - - graph_block_tables = self.runner.graph_block_tables[:num_seqs] - for i, block_table in enumerate(block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - graph_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - graph_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - return torch.from_numpy(graph_block_tables).to( - device=self.runner.device, non_blocking=True) - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): - """Build attention metadata with on-device tensors. - - Args: - seq_lens: The maybe padded sequence lengths of the input sequences. - query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. - """ - prefix_cache_hit = any([ - inter_data.prefix_cache_hit - for inter_data in self.input_builder.inter_data_list - ]) - - for inter_data in self.input_builder.inter_data_list: - self._add_seq_group(inter_data, - self.input_builder.chunked_prefill_enabled, - prefix_cache_hit) - - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - max_query_len = max(query_lens) - decode_query_lens = query_lens[self.num_prefills:] - if len(decode_query_lens) > 0: - max_decode_query_len = max(decode_query_lens) - else: - max_decode_query_len = 1 - max_prefill_seq_len = max(self.prefill_seq_lens, default=0) - max_decode_seq_len = max(self.curr_seq_lens, default=0) - num_decode_tokens = self.num_decode_tokens - query_start_loc = list(accumulate(query_lens, initial=0)) - seq_start_loc = list(accumulate(seq_lens, initial=0)) - - num_seqs = len(seq_lens) - if use_captured_graph: - self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) - self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER * - cuda_graph_pad_size) - num_decode_tokens = batch_size - self.num_prefill_tokens - - block_tables = self._get_graph_runner_block_tables( - num_seqs, self.block_tables) - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=0, - dtype=torch.int, - device=device, - ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) - - assert device is not None - context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, - device, self.runner.pin_memory) - seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, - self.runner.pin_memory) - slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, - device, self.runner.pin_memory) - query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, - device, - self.runner.pin_memory) - seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, - device, 
self.runner.pin_memory) - - context_chunk_cu_seq_lens = None - context_chunk_starts = None - context_chunk_seq_tot = None - context_chunk_max_seq_lens = None - - if (self.chunked_prefill_enabled or self.enable_prefix_caching) \ - and self.num_prefills > 0 \ - and context_lens_tensor is not None \ - and context_lens_tensor[:self.num_prefills].max() > 0: - - # NOTE: it is recommended you read the `Chunked Prefill` section in - # the comment at the top of the file before trying to understand - # the following code - - num_prefills_with_context = \ - (context_lens_tensor[:self.num_prefills] > 0).sum().item() - - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = \ - self.context_chunk_workspace_size // num_prefills_with_context - - # align max_context_chunk to page_size by rounding down, - # currently the `gather_and_maybe_dequant_cache` kernel cannot - # handle `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, self.page_size) - assert max_context_chunk > 0 - num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) - - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - context_chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32)\ - .unsqueeze(1).expand(-1, self.num_prefills)\ - * max_context_chunk - chunk_ends = torch.min(context_lens_tensor[:self.num_prefills]\ - .unsqueeze(0), context_chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) - _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device)\ - .unsqueeze(-1) - context_chunk_cu_seq_lens = \ - torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) - context_chunk_max_seq_lens = \ - chunk_seq_lens.max(dim=1).values.tolist() - context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() - assert max(context_chunk_seq_tot) <= \ - self.context_chunk_workspace_size - - return self.runner.attn_backend.make_metadata( - # Required by ModelRunner - use_cuda_graph=use_captured_graph, # Not Attention Related - # Required by Attention Metadata - num_prefills=self.num_prefills, - slot_mapping=slot_mapping_tensor, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - # Required by Attention Metadata (not used) - multi_modal_placeholder_index_maps=None, # Not Attention Related - enable_kv_scales_calculation=False, - # MLACommonMetadata - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_decode_query_len=max_decode_query_len, - max_prefill_seq_len=max_prefill_seq_len, - max_decode_seq_len=max_decode_seq_len, - query_start_loc=query_start_loc_tensor, - seq_start_loc=seq_start_loc_tensor, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - head_dim=self.runner.model_config.get_head_size(), - is_profile_run=self.runner.in_profile_run, - # MLACommonMetadata Chunk prefill specific - context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, - context_chunk_starts=context_chunk_starts, - context_chunk_seq_tot=context_chunk_seq_tot, - context_chunk_max_seq_lens=context_chunk_max_seq_lens, - ) - - -class 
MLACommonImpl(MLAAttentionImpl[T], Generic[T]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - q_lora_rank: Optional[int], - kv_lora_rank: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - qk_head_dim: int, - v_head_dim: int, - kv_b_proj: ColumnParallelLinear, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing not supported in V0.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_head_dim - self.v_head_dim = v_head_dim - self.kv_b_proj = kv_b_proj - - self.triton_fa_func = triton_attention - # Handle the differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) - - def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale, - return_softmax_lse, **kwargs): - maybe_padded_v = v - if self._pad_v: - maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0) - - if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ - and not return_softmax_lse: - attn_out = self.triton_fa_func( - q, - k, - maybe_padded_v, - None, # output - kwargs["cu_seqlens_q"], - kwargs["cu_seqlens_k"], - kwargs["max_seqlen_q"], - kwargs["max_seqlen_k"], - kwargs["causal"], - softmax_scale, - None, # bias - ) - elif is_vllm_fa: - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_softmax_lse=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - else: - # Use return_attn_probs instead of return_softmax_lse for RoCM - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - return_attn_probs=return_softmax_lse, - softmax_scale=softmax_scale, - **kwargs, - ) - - # Unpack the output if there is multiple results, - # triton always returns (output, softmax_lse), - # vllm_flash_attn returns (output, softmax_lse) when - # `return_softmax_lse = True` - # flash_attn (RoCM) returns (output, softmax_lse, ...) 
when - # `return_attn_probs = True` - rest = None - if isinstance(attn_out, tuple): - attn_out, *rest = attn_out - - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. - if return_softmax_lse: - assert rest is not None - return attn_out, rest[0] - return attn_out - - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - - def _compute_prefill_context( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ): - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - assert prefill_metadata.context_chunk_seq_tot is not None - assert prefill_metadata.context_chunk_cu_seq_lens is not None - assert prefill_metadata.context_chunk_starts is not None - assert prefill_metadata.context_chunk_max_seq_lens is not None - assert prefill_metadata.context_lens_tensor is not None - - output = None - iters = len(prefill_metadata.context_chunk_seq_tot) - - # Fetch from attn_metadata directly, since it late bound by - # MLAAttentionState, grabbing it directly `attn_metadata` can avoid - # any weirdness around prefill_metadata caching - assert attn_metadata.context_chunk_workspace is not None - workspace = attn_metadata.context_chunk_workspace - - for i in range(iters): - toks = prefill_metadata.context_chunk_seq_tot[i] - - ops.gather_and_maybe_dequant_cache( - src_cache=kv_c_and_k_pe_cache, - dst=workspace, - block_table=prefill_metadata.block_tables, - 
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], - batch_size=prefill_metadata.num_prefills, - kv_cache_dtype=self.kv_cache_dtype, - scale=k_scale, - seq_starts=prefill_metadata.context_chunk_starts[i], - ) - - kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank] - k_pe = workspace[:toks]\ - [..., self.kv_lora_rank:].unsqueeze(1) - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) - - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - ) - - if output is None: - output = attn_output - output_lse = attn_softmax_lse - else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) - merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, - prefix_output=output, - prefix_lse=output_lse, - suffix_output=attn_output, - suffix_lse=attn_softmax_lse, - ) - output = output_tmp - output_lse = output_lse_tmp - - return output, output_lse - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor, - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - has_context = prefill_metadata.context_lens_tensor is not None \ - and prefill_metadata.context_lens_tensor.max() > 0 - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - output = self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.query_start_loc, - max_seqlen_q=prefill_metadata.max_prefill_seq_len, - max_seqlen_k=prefill_metadata.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=has_context, - ) - - if has_context: - # ROCm flash_attn_varlen_func will return 3 objects instead of 2 - suffix_output, suffix_lse = output - context_output, context_lse = self._compute_prefill_context( \ - q, kv_c_and_k_pe_cache, attn_metadata, k_scale) - - output = torch.empty_like(suffix_output) - merge_attn_states( - output=output, - prefix_output=context_output, - prefix_lse=context_lse, - suffix_output=suffix_output, - suffix_lse=suffix_lse, - ) - - # unpad if necessary - if self._pad_v: - output = output[..., :v.shape[-1]] - - return output.flatten(start_dim=-2) - - @abstractmethod - def _forward_decode( - self, - ql_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, - ) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, # query in unified attn - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: T, - 
output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if output is not None: - raise NotImplementedError( - "output is not yet supported for MLAImplBase") - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for MLAImplBase") - - if attn_metadata.is_profile_run and \ - attn_metadata.context_chunk_workspace is not None: - # During the profile run try to simulate to worse case output size - # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` - # since this can be large - _ = torch.empty( - (attn_metadata.context_chunk_workspace.shape[0], - self.num_heads, self.qk_nope_head_dim + self.v_head_dim), - device=k_c_normed.device, - dtype=k_c_normed.dtype, - ) - - has_decode = attn_metadata.decode_metadata is not None - has_prefill = attn_metadata.prefill_metadata is not None - - num_prefill_tokens: int = attn_metadata.num_prefill_tokens - q = q.view(-1, self.num_heads, self.qk_head_dim) - - decode_q = q[num_prefill_tokens:] - - prefill_q = q[:num_prefill_tokens] - prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_k_c_normed = k_c_normed[:num_prefill_tokens] - - # write the latent and rope to kv cache - if kv_cache.numel() > 0: - ops.concat_and_cache_mla( - k_c_normed, - k_pe.squeeze(1), - kv_cache, - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - - output = torch.empty(attn_metadata.num_prefill_tokens + - attn_metadata.num_decode_tokens, - self.v_head_dim * self.num_heads, - device=q.device, - dtype=q.dtype) - if has_prefill: - output[:num_prefill_tokens] = self._forward_prefill( - prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, - attn_metadata, layer._k_scale) - - if has_decode: - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - - output[num_prefill_tokens:] = self._forward_decode( - decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) - - return output diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py deleted file mode 100644 index 587d08858b92..000000000000 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ /dev/null @@ -1,407 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Optional, Type, Union - -import torch - -import vllm.envs as envs -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata, - MLACommonMetadataBuilder, - MLACommonState) -from vllm.attention.backends.utils import (compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) -from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd, - get_aiter_mla_metadata) - - -def is_aiter_mla_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER \ - and envs.VLLM_ROCM_USE_AITER_MLA - - -class AiterMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "ROCM_AITER_MLA" - - @staticmethod - def get_impl_cls() -> 
Type["AiterMLAImpl"]: - return AiterMLAImpl - - @staticmethod - def get_metadata_cls() -> Type["AiterMLAMetadata"]: - return AiterMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]: - return AiterMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["AiterMLAState"]: - return AiterMLAState - - -@dataclass -class AiterMLAMetadata(MLACommonMetadata): - # The following 5 tensors are for current version of AITER MLA - block_table_bound: Optional[torch.Tensor] = None - # The indptr of the paged kv cache, shape: [batch_size + 1] - paged_kv_indptr: Optional[torch.Tensor] = None - # The page indices of the paged kv cache - paged_kv_indices: Optional[torch.Tensor] = None - # The number of entries in the last page of each request in - # the paged kv cache, shape: [batch_size] - paged_kv_last_page_lens: Optional[torch.Tensor] = None - - # This is just to make new AITER MLA API work - # -- MTP support is not added yet. - qo_indptr: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self): - prefill_metadata = super().prefill_metadata - self._cached_prefill_metadata = prefill_metadata - - if prefill_metadata is not None: - prefill_metadata.paged_kv_indptr = self.paged_kv_indptr - prefill_metadata.paged_kv_indices = self.paged_kv_indices - prefill_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - prefill_metadata.block_table_bound = self.block_table_bound - prefill_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_prefill_metadata = self.__class__( - **prefill_metadata.__dict__) - - return self._cached_prefill_metadata - - @property - def decode_metadata(self): - decode_metadata = super().decode_metadata - - self._cached_decode_metadata = decode_metadata - - if decode_metadata is not None: - decode_metadata.paged_kv_indptr = self.paged_kv_indptr - decode_metadata.paged_kv_indices = self.paged_kv_indices - decode_metadata\ - .paged_kv_last_page_lens = self.paged_kv_last_page_lens - decode_metadata.block_table_bound = self.block_table_bound - decode_metadata.qo_indptr = self.qo_indptr - - # update the cache - self._cached_decode_metadata = self.__class__( - **decode_metadata.__dict__) - - return self._cached_decode_metadata - - -class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): - BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] - - def __init__(self, input_builder): - super().__init__(input_builder) - assert self.block_size == 1, "AITER MLA requires only block size 1." - - def prepare(self): - super().prepare() - self.paged_kv_indices: list[int] = [] - self.paged_kv_indptr: list[int] = [0] - self.paged_kv_last_page_lens: list[int] = [] - self.total_blocks = 0 - self.qo_indptr: list[int] = [0] - - def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, - prefix_cache_hit: bool): - """Add a sequence group to the metadata. Specifically update/append - 1. context length. - 2. block table. - 3. slot mapping. 
- """ - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables - - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): - self.context_lens.append(context_len) - if is_prompt: - self.num_prefills += 1 - self.num_prefill_tokens += token_len - self.prefill_seq_lens.append(seq_len) - else: - self.num_decode_tokens += query_len - self.curr_seq_lens.append(curr_seq_len) - - # Compute block table. - # TODO(sang): Combine chunked prefill and prefix caching by - # only allowing multiple of block_size chunk size. - # NOTE: This only works for oooooooxxx style attention. - block_table = [] - if prefix_cache_hit: - # NOTE(woosuk): For flash-attn, the block table should - # include the entries for the incoming prefill tokens. - block_table = block_tables[seq_id] - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): - if curr_sliding_window_block == 0: - block_table = block_tables[seq_id] - else: - block_table = block_tables[seq_id][ - -curr_sliding_window_block:] - self.block_tables.append(block_table) - - # Compute slot mapping. - is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, - context_len, - self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) - if is_profile_run: - return - - # Update paged_kv_* tensors only for non-profile run - block_table = block_tables[seq_id] - self._update_paged_kv_tensors(block_table, seq_len) - - def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int): - # Get the number of valid blocks based on sequence length. - # If seq_len = 16, block_size = 16, - # block_table_bound is 1 with 1 valid block. - # If seq_len = 15, block_size = 16, - # block_table_bound is 0 + 1 with 1 valid block. 
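        # If seq_len = 17, block_size = 16,
        # block_table_bound is 1 + 1 with 2 valid blocks.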
- self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size - self.paged_kv_indices.extend(block_table[:block_table_bound]) - self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + - block_table_bound) - self.qo_indptr.append(self.qo_indptr[-1] + 1) - - last_page_len = seq_len % self.block_size - if last_page_len == 0: - last_page_len = self.block_size - self.paged_kv_last_page_lens.append(last_page_len) - - def build(self, seq_lens: list[int], query_lens: list[int], - cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata: - metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size, - batch_size) - device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 - - if use_captured_graph: - last_paged_kv_indptr = self.paged_kv_indptr[-1] - self.paged_kv_indptr.extend([last_paged_kv_indptr] * - cuda_graph_pad_size) - self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size) - last_qo_indptr = self.qo_indptr[-1] - self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size) - - # For current version of AITER MLA - if len(self.paged_kv_indptr) > 0: - # extend to the maximum number of blocks as returned by the - # scheduler - self.paged_kv_indices.extend( - [0] * (self.total_blocks - len(self.paged_kv_indices))) - paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, - device=device, - dtype=torch.int) - paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, - device=device, - dtype=torch.int) - paged_kv_last_page_lens_tensor = torch.tensor( - self.paged_kv_last_page_lens, device=device, dtype=torch.int) - block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - - 1, - device=device, - dtype=torch.int) - - qo_indptr = torch.tensor(self.qo_indptr, - device=device, - dtype=torch.int) - else: - paged_kv_indices_tensor = None - paged_kv_indptr_tensor = None - paged_kv_last_page_lens_tensor = None - block_table_bound_tensor = None - qo_indptr = None - - metadata.paged_kv_indptr = paged_kv_indptr_tensor - metadata.paged_kv_indices = paged_kv_indices_tensor - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor - metadata.block_table_bound = block_table_bound_tensor - metadata.qo_indptr = qo_indptr - - return metadata - - -class AiterMLAState(MLACommonState[AiterMLAMetadata]): - - @contextmanager - def graph_capture(self, max_batch_size: int): - kv_indices, kv_indptr, last_page_lens, qo_indptr = \ - get_aiter_mla_metadata( - max_batch_size=max_batch_size, - block_size=self.runner.block_size, - max_block_per_batch=\ - self.runner.get_max_block_per_batch(), - device=self.runner.device) - self._paged_kv_indices_tensor = kv_indices - self._paged_kv_indptr_tensor = kv_indptr - self._paged_kv_last_page_lens_tensor = last_page_lens - self._qo_indptr_tensor = qo_indptr - - with super().graph_capture(max_batch_size): - yield - - del self._paged_kv_indices_tensor - del self._paged_kv_indptr_tensor - del self._paged_kv_last_page_lens_tensor - del self._qo_indptr_tensor - - def graph_capture_get_metadata_for_batch( - self, - batch_size: int, - is_encoder_decoder_model: bool = False) -> AiterMLAMetadata: - - metadata = super().graph_capture_get_metadata_for_batch( - batch_size, is_encoder_decoder_model) - - paged_kv_indptr = self._paged_kv_indptr_tensor[:batch_size + 1] - paged_kv_indices = self._paged_kv_indices_tensor - paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[: - batch_size] - qo_indptr = 
self._qo_indptr_tensor[:batch_size + 1] - - metadata.paged_kv_indptr = paged_kv_indptr - metadata.paged_kv_indices = paged_kv_indices - metadata.paged_kv_last_page_lens = paged_kv_last_page_lens - metadata.qo_indptr = qo_indptr - - return metadata - - def get_graph_input_buffers(self, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers( - attn_metadata, is_encoder_decoder_model) - input_buffers[ - 'paged_kv_indptr'] = attn_metadata.decode_metadata.paged_kv_indptr - input_buffers[ - "paged_kv_indices"] = attn_metadata.\ - decode_metadata.paged_kv_indices - input_buffers[ - "paged_kv_last_page_lens"] = attn_metadata.\ - decode_metadata.paged_kv_last_page_lens - input_buffers['qo_indptr'] = attn_metadata.qo_indptr - - return input_buffers - - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata: AiterMLAMetadata, - is_encoder_decoder_model: bool = False): - super().prepare_graph_input_buffers(input_buffers, attn_metadata, - is_encoder_decoder_model) - - num_total_blocks = attn_metadata.decode_metadata.paged_kv_indices.shape[ - 0] - input_buffers["paged_kv_indptr"].copy_( - attn_metadata.decode_metadata.paged_kv_indptr, non_blocking=True) - input_buffers["paged_kv_indices"][:num_total_blocks].copy_( - attn_metadata.decode_metadata.paged_kv_indices, non_blocking=True) - input_buffers["paged_kv_last_page_lens"].copy_( - attn_metadata.decode_metadata.paged_kv_last_page_lens, - non_blocking=True) - input_buffers["qo_indptr"].copy_( - attn_metadata.decode_metadata.qo_indptr, non_blocking=True) - - -class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - from aiter import flash_attn_varlen_func - self.flash_attn_varlen_func = flash_attn_varlen_func - - def _flash_attn_varlen_diff_headdims( - self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - softmax_scale: float, return_softmax_lse: bool, - **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]: - output = self.flash_attn_varlen_func( - q, - k, - v, - **kwargs, - ) - - return output - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: AiterMLAMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.empty(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) - - aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, - attn_metadata.qo_indptr, - attn_metadata.max_query_len, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - 
attn_metadata.paged_kv_last_page_lens) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py deleted file mode 100644 index 9262144e37b5..000000000000 --- a/vllm/attention/backends/rocm_flash_attn.py +++ /dev/null @@ -1,953 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer ROCm GPUs.""" -import itertools -from dataclasses import dataclass -from functools import cache -from typing import List, Optional, Tuple, Type - -import torch - -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym) -from vllm.platforms import current_platform - -logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 256 - - -@cache -def is_rocm_aiter_paged_attn_enabled() -> bool: - return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \ - and envs.VLLM_ROCM_USE_AITER \ - - -@cache -def _get_paged_attn_module() -> PagedAttention: - """ - Initializes the appropriate PagedAttention module from `attention/ops`, - which is used as helper function - by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`. - - The choice of attention module depends on whether - AITER paged attention is enabled: - - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`. - - Otherwise, it defaults to using the original `PagedAttention`. 
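    Concretely (an illustrative note added by the editor, not part of the
    original file), with the flags read by `is_rocm_aiter_paged_attn_enabled`
    above:

        VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=1 -> AITERPagedAttention
        any other combination of the two flags                 -> PagedAttention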
- """ - if is_rocm_aiter_paged_attn_enabled(): - # Import AITERPagedAttention only when the flag is enabled - from vllm.attention.ops.rocm_aiter_paged_attn import ( - AITERPagedAttention) - return AITERPagedAttention() - return PagedAttention() - - -class ROCmFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ROCM_FLASH" - - @staticmethod - def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: - return ROCmFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return ROCmFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]: - return ROCmFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - paged_attn = _get_paged_attn_module() - return paged_attn.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - paged_attn = _get_paged_attn_module() - paged_attn.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for FlashAttentionBackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. 
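    # NOTE (added for clarity, not in the original file): these cumulative
    # offsets are used to slice the flattened token dimension, e.g. the tokens
    # of sequence `i` are `flat_tokens[seq_start_loc[i]:seq_start_loc[i + 1]]`
    # (`flat_tokens` is an illustrative name, not a field of this class).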
- seq_start_loc: Optional[torch.Tensor] = None - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... - - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.block_tables is not None - - self._cached_prefill_metadata = ROCmFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - multi_modal_placeholder_index_maps=self. - multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - # Begin encoder & cross attn fields below... 
- encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = ROCmFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -class ROCmFlashAttentionMetadataBuilder( - CommonMetadataBuilder[ROCmFlashAttentionMetadata]): - - _metadata_cls = ROCmFlashAttentionMetadata - - -def _make_alibi_bias(alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: Optional[List[int]], - make_attn_mask: bool = True) -> List[torch.Tensor]: - attn_biases = [] - if seq_lens: - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat( - (num_heads, 1, 1)).to(alibi_slopes.device) - bias.mul_(alibi_slopes[:, None, None]) - if make_attn_mask: - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( - alibi_slopes.device) - attn_biases.append((bias + inf_mask).to(dtype)) - else: - attn_biases.append(bias.to(dtype)) - - return attn_biases - - -def _get_seq_len_block_table_args( - attn_metadata: ROCmFlashAttentionMetadata, - attn_type: str, -) -> tuple: - ''' - The particular choice of sequence-length - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths - Encoder attn -> select encoder sequence lengths fields - Encoder-only attn -> select prefill sequence lengths with - bidirectional attention - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention op - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention, encoder-only - - Returns: - - * Appropriate sequence-lengths tensors for query and key - * Appropriate max sequence-length scalar - * Causal masking flag - ''' - - if attn_type == AttentionType.ENCODER: - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - causal_mask = False - - # No block tables associated with encoder attention - return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, - query_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.encoder_seq_lens, causal_mask) - - elif attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, we use the prefill sequence lengths - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - # Encoder-only models typically use bidirectional attention - causal_mask = False - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - - elif attn_type == AttentionType.DECODER: - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - assert attn_metadata.seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - query_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - max_seq_len = attn_metadata.max_prefill_seq_len - causal_mask = True - - return (query_seq_start_loc, max_seq_len, query_seq_start_loc, - max_seq_len, attn_metadata.seq_lens, causal_mask) - elif attn_type == AttentionType.ENCODER_DECODER: - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens_tensor is not None - query_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.seq_lens)), - device=attn_metadata.encoder_seq_lens_tensor.device, - dtype=attn_metadata.encoder_seq_lens_tensor.dtype) - - assert attn_metadata.encoder_seq_lens is not None - assert attn_metadata.seq_lens_tensor is not None - key_seq_start_loc = torch.tensor( - list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)), - device=attn_metadata.seq_lens_tensor.device, - dtype=attn_metadata.seq_lens_tensor.dtype) - causal_mask = False - - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (query_start_loc, attn_metadata.max_prefill_seq_len, - key_seq_start_loc, attn_metadata.max_encoder_seq_len, - attn_metadata.seq_lens, causal_mask) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class 
ROCmFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| - |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "ROCM_FLASH backend.") - if use_irope: - logger.warning_once( - "Using irope in ROCm Flash Attention is not supported yet, it " - "will fail back to global attention for long context.") - if use_irope: - logger.warning( - "Using irope in V0 is not supported yet, it will fall back " - "to global attention for long context.") - if logits_soft_cap is None: - # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - self.logits_soft_cap = 0.0 - else: - self.logits_soft_cap = logits_soft_cap - self.attn_type = attn_type - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.paged_attn_module = _get_paged_attn_module() - supported_head_sizes = self.paged_attn_module.get_supported_head_sizes( - ) - - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.use_naive_attn = False - # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN - if self.use_triton_flash_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Triton FlashAttention does not support attention" - " logits soft capping." 
- " please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - - from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 - triton_attention) - self.triton_attn_func = triton_attention - logger.debug("Using Triton FA in ROCmBackend") - if self.sliding_window != (-1, -1): - logger.warning("ROCm Triton FA does not currently support " - "sliding window attention. If using half " - "precision, please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") - else: - # if not using triton, navi3x/navi21/navi10 do not use flash-attn - # either - if not current_platform.has_device_capability(90): - self.use_naive_attn = True - else: - try: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.fa_attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") - except ModuleNotFoundError: - self.use_naive_attn = True - - if self.use_naive_attn: - if logits_soft_cap is not None: - raise ValueError( - "ROCm Naive FlashAttention does not support " - "attention logits soft capping.") - - self.sdpa_attn_func = _sdpa_attention - logger.debug("Using naive (SDPA) attention in ROCmBackend") - - self.aiter_kv_scales_initialized = False - self.force_fp8_attention = ( - get_current_vllm_config() is not None - and get_current_vllm_config().model_config.override_attention_dtype - == "fp8") - - def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - tokens, n_kv_heads, head_dim = x.shape - return (x[:, :, - None, :].expand(tokens, n_kv_heads, n_rep, - head_dim).reshape(tokens, n_kv_heads * n_rep, - head_dim)) - - def fused_output_quant_supported(self, quant_key: QuantKey): - if self.use_triton_flash_attn: - return quant_key == kFp8StaticTensorSym - - # Only supported in the Triton backend - return False - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: ROCmFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * ROCmFlashAttentionImpl.forward() may be invoked for both self- and - cross-attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. 
- * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - * ENCODER_ONLY: bidirectional attention with no KV caching; - use prefill sequence attributes - - Args: - layer: Attention layer instance. - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size * num_kv_heads * head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Optional output tensor. - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None and not self.use_triton_flash_attn: - raise NotImplementedError( - "fused output quantization only supported for Triton" - " implementation in ROCMFlashAttentionImpl for now") - - if output_block_scale is not None: - raise NotImplementedError( - "fused nvfp4 output quantization is not supported" - " for ROCMFlashAttentionImpl") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - paged_attn = self.paged_attn_module - - # Reshaping kv tensors is required for AITER paged attention kernel - # because it works on a different tensor shape, - # when the size of one element is one byte (int8/fp8 dtypes). - # This reshaping is only required on the first forward call - # and the kv cache must not be empty. 
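        # NOTE (added for clarity, not in the original file): the branch below
        # runs at most once per layer (guarded by aiter_kv_scales_initialized).
        # It derives block_size from the flat cache layout and broadcasts the
        # scalar k/v scales into per-slot tensors of shape
        # (num_kv_heads, num_blocks * block_size), which appears to be the
        # shape the AITER path consumes for 1-byte (int8/fp8) KV caches.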
- if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1 - and not self.aiter_kv_scales_initialized - and kv_cache.shape != torch.Size([0])): - num_blocks = kv_cache.shape[1] - block_size = kv_cache.shape[2] // (self.num_kv_heads * - self.head_size) - k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size), - dtype=torch.float32, - device=kv_cache.device) - self.aiter_kv_scales_initialized = True - k_scale.fill_(layer._k_scale.item()) - v_scale.fill_(layer._v_scale.item()) - layer._k_scale = k_scale - layer._v_scale = v_scale - - # Only update KV cache for decoder self-attention - # and encoder-decoder cross-attention - if self.attn_type not in [ - AttentionType.ENCODER, AttentionType.ENCODER_ONLY - ] and kv_cache.numel() > 0: - key_cache, value_cache = paged_attn.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if key is not None and value is not None: - # Reshape the input keys and values and store them in the - # cache. If kv_cache is not provided, the new key and value - # tensors are not cached. This happens during the initial - # memory profiling run. - paged_attn.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping - if self.attn_type != AttentionType.ENCODER_DECODER else - attn_metadata.cross_slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if self.attn_type != AttentionType.ENCODER: - num_prefill_tokens = attn_metadata.num_prefill_tokens - elif self.attn_type == AttentionType.ENCODER_ONLY: - # For encoder-only models, all tokens are processed in one go - num_prefill_tokens = query.shape[0] - else: - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - - # For encoder-only and encoder models, - # we process all tokens at once - # For decoder and encoder-decoder, - # we may need to limit key/value to prefill tokens - if key is not None and value is not None \ - and self.attn_type not in [AttentionType.ENCODER_DECODER, - AttentionType.ENCODER_ONLY]: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - # normal attention and DECODER - if self.attn_type == AttentionType.DECODER and ( - kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = (prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - attn_metadata.seq_lens, True) - # prefix-enabled attention and ENCODER/ENCODER_DECODER - else: - (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, - key_max_seq_len, seq_lens, - causal_mask) = _get_seq_len_block_table_args( - prefill_meta, self.attn_type) - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # triton attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
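                # NOTE (added for clarity, not in the original file): three
                # prompt-attention paths are dispatched below, in priority
                # order: Triton FA (use_triton_flash_attn), naive SDPA
                # (use_naive_attn), and otherwise the CK flash_attn_varlen_func.
                # All three write into output[:num_prefill_tokens].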
- attn_masks = None - if self.use_triton_flash_attn: - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - seq_lens, - make_attn_mask=causal_mask) # type: ignore - - use_fp8_scales = (layer._q_scale and layer._k_scale - and layer._v_scale and layer._prob_scale - and (self.kv_cache_dtype == "fp8" - or self.force_fp8_attention)) - - full_scales = ( - layer._q_scale.item(), layer._k_scale.item(), - layer._v_scale.item(), - layer._prob_scale.item()) if use_fp8_scales else None - self.triton_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - key_seq_start_loc, - query_max_seq_len, - key_max_seq_len, - causal_mask, - self.scale, - attn_masks[0][None] - if attn_masks is not None else None, - full_scales, - output_scale, - ) - elif self.use_naive_attn: - if self.num_kv_heads != self.num_heads: - # Interleave for MQA workaround. - key = self.repeat_kv(key, self.num_queries_per_kv) - value = self.repeat_kv(value, self.num_queries_per_kv) - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, - query.dtype, - attn_metadata.seq_lens, - make_attn_mask=causal_mask) # type: ignore - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - # sdpa math backend attention - self.sdpa_attn_func( - query, - key, - value, - output[:num_prefill_tokens], - query_seq_start_loc, - num_prefill_tokens, - self.num_heads, - self.head_size, - self.scale, - attn_masks, - ) - else: - # upstream FA does not support an output arg, copy - output[:num_prefill_tokens] = self.fa_attn_func( - q=query, - k=key, - v=value, - cu_seqlens_q=query_seq_start_loc, - cu_seqlens_k=key_seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=key_max_seq_len, - softmax_scale=self.scale, - causal=causal_mask, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - - else: - # prefix-enabled attention - - # not applicable for encoder-only models - if self.attn_type != AttentionType.ENCODER_ONLY: - output[:num_prefill_tokens] = paged_attn.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window[0], - layer._k_scale, - layer._v_scale, - ) - # Skip decode phase for encoder-only models - if (decode_meta := attn_metadata.decode_metadata) and ( - self.attn_type != AttentionType.ENCODER_ONLY): - # Decoding run. 
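            # NOTE (added for clarity, not in the original file): the decode
            # path below chooses between the ROCm custom paged-attention kernel
            # and the generic paged_attn.forward_decode. The custom kernel
            # splits each sequence into partitions of _PARTITION_SIZE_ROCM
            # (= 256) tokens, e.g. max_seq_len = 1000 gives
            # max_num_partitions = (1000 + 255) // 256 = 4.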
- # Whether to use rocm custom paged attention or not - num_seqs, num_heads, head_size = decode_query.shape - block_size = value_cache.shape[3] - gqa_ratio = num_heads // self.num_kv_heads - from vllm.platforms.rocm import use_rocm_custom_paged_attention - use_custom = use_rocm_custom_paged_attention( - decode_query.dtype, head_size, block_size, gqa_ratio, - decode_meta.max_decode_seq_len, self.sliding_window, - self.kv_cache_dtype, self.alibi_slopes) - - if use_custom: - max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type - != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len) - assert max_seq_len is not None - max_num_partitions = ( - (max_seq_len + _PARTITION_SIZE_ROCM - 1) // - _PARTITION_SIZE_ROCM) - assert _PARTITION_SIZE_ROCM % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=query.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - - query_start_loc = None - ops.paged_attention_rocm( - output[num_prefill_tokens:], - exp_sums, - max_logits, - tmp_output, - decode_query, - key_cache, - value_cache, - self.num_kv_heads, - self.scale, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - query_start_loc, - block_size, - max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - output_scale, - ) - else: - # PagedAttention does not support fused quant, manually quantize - if output_scale is None: - out_pa = output[num_prefill_tokens:] - else: - out_pa = torch.empty_like(output[num_prefill_tokens:], - dtype=query.dtype) - - out_pa[:] = paged_attn.forward_decode( - decode_query, - key_cache, - value_cache, - decode_meta.block_tables - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, - decode_meta.seq_lens_tensor - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, - decode_meta.max_decode_seq_len - if self.attn_type != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Manually perform quantization - if output_scale is not None: - out_uq = out_pa.view(-1, self.num_heads * self.head_size) - out_q = output.view(-1, self.num_heads * self.head_size) - ops.scaled_fp8_quant(out_uq, - output_scale, - output=out_q[num_prefill_tokens:]) - - # Reshape the output tensor. 
- return output.view(-1, self.num_heads * self.head_size) - - -def _sdpa_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - output: torch.Tensor, - seq_lens: torch.Tensor, - num_tokens: int, - num_heads: int, - head_size: int, - scale: float, - attn_masks: Optional[List[torch.Tensor]] = None, -) -> torch.Tensor: - start = 0 - assert output.shape == (num_tokens, num_heads, head_size) - assert output.dtype == query.dtype - assert output.device == query.device - - for i, seq_len in enumerate(seq_lens): - end = start + seq_len - with torch.nn.attention.sdpa_kernel( - torch.nn.attention.SDPBackend.MATH): - sub_out = torch.nn.functional.scaled_dot_product_attention( - query[:, start:end, :], - key[:, start:end, :], - value[:, start:end, :], - dropout_p=0.0, - is_causal=attn_masks is None, - attn_mask=attn_masks[i] if attn_masks else None, - scale=scale).movedim(query.dim() - 2, 0) - output[start:end, :, :] = sub_out - start = end - - return output diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py deleted file mode 100644 index fba5b5f6bca8..000000000000 --- a/vllm/attention/backends/triton_mla.py +++ /dev/null @@ -1,111 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import (MLACommonBackend, - MLACommonImpl, - MLACommonMetadata) -from vllm.attention.ops.triton_decode_attention import decode_attention_fwd - - -class TritonMLABackend(MLACommonBackend): - - @staticmethod - def get_name() -> str: - return "TRITON_MLA" - - @staticmethod - def get_impl_cls() -> Type["TritonMLAImpl"]: - return TritonMLAImpl - - -class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] - if any(unsupported_features): - raise NotImplementedError( - "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "TritonMLAImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "TritonMLA with FP8 KV cache not yet supported") - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - B = q_nope.shape[0] - - q = torch.cat([q_nope, q_pe], dim=-1) - o = torch.zeros(B, - self.num_heads, - self.kv_lora_rank, - dtype=q.dtype, - device=q.device) - - num_kv_splits = 4 # TODO: heuristic - - # TODO(lucas) Allocate ahead of time - attn_logits = 
torch.empty( - ( - B, - self.num_heads, - num_kv_splits, - # NOTE(lucas) idk why the +1 is here but sglang has it so we - # just mirror that - self.kv_lora_rank + 1, - ), - dtype=torch.float32, - device=q.device, - ) - - # Add a head dim of 1 - kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2) - kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] - PAGE_SIZE = kv_c_and_k_pe_cache.size(1) - - # Run MQA - decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, attn_logits, - num_kv_splits, self.scale, PAGE_SIZE) - - return self._v_up_proj(o) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index b28e6a4237cb..3f15580872c7 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -338,10 +338,9 @@ def graph_capture_get_metadata_for_batch( # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -360,10 +359,9 @@ def get_graph_input_buffers( # The encoder decoder model works only with XFormers and # Flash Attention backend. Assert the same. assert self.runner.attn_backend.get_name() in \ - ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"], \ - f"Expected attn_backend name to be either 'XFORMERS'," \ - f"'ROCM_FLASH', or 'FLASH_ATTN', but " \ - f"got '{self.runner.attn_backend.get_name()}'" + ["XFORMERS", "FLASH_ATTN"], \ + f"Expected attn_backend name to be either 'XFORMERS' or " \ + f"'FLASH_ATTN', but got '{self.runner.attn_backend.get_name()}'" self._add_additional_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) return input_buffers diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py deleted file mode 100644 index 302d3d7ea903..000000000000 --- a/vllm/attention/backends/xformers.py +++ /dev/null @@ -1,805 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with xFormers and PagedAttention.""" -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type - -import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (AttentionBias, - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMaskWithTensorBias) - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import ( - CommonAttentionState, CommonMetadataBuilder, - get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, - is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class XFormersBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "XFORMERS" - - @staticmethod - def get_impl_cls() -> Type["XFormersImpl"]: - return 
XFormersImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return XFormersMetadata - - @staticmethod - def get_builder_cls() -> Type["XFormersMetadataBuilder"]: - return XFormersMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for XFormersbackend. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # FIXME: It is for flash attn. - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] = None - - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] = None - - # Max number of query tokens among request in the batch. - max_decode_query_len: Optional[int] = None - - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - - # Self-attention prefill/decode metadata cache - _cached_prefill_metadata: Optional["XFormersMetadata"] = None - _cached_decode_metadata: Optional["XFormersMetadata"] = None - - # Begin encoder attn & enc/dec cross-attn fields... 
- - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - encoder_seq_start_loc: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[AttentionBias]] = None - self.encoder_attn_bias: Optional[List[AttentionBias]] = None - self.cross_attn_bias: Optional[List[AttentionBias]] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return is_all_encoder_attn_metadata_set(self) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return is_all_cross_attn_metadata_set(self) - - @property - def prefill_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - # Recover cached prefill-phase attention - # metadata structure - return self._cached_prefill_metadata - - assert ((self.seq_lens is not None) - or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[:self.num_prefill_tokens]) - seq_lens = (None if self.seq_lens is None else - self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) - block_tables = (None if self.block_tables is None else - self.block_tables[:self.num_prefills]) - - # Construct & cache prefill-phase attention metadata structure - self._cached_prefill_metadata = XFormersMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["XFormersMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - # Recover cached decode-phase attention - # metadata structure - return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - - # Compute some attn_metadata fields which default to None - slot_mapping = (None if self.slot_mapping is None else - self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) - block_tables = (None if self.block_tables is None else - self.block_tables[self.num_prefills:]) - - # Construct & cache decode-phase attention metadata structure - self._cached_decode_metadata = XFormersMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=True, - seq_lens_tensor=seq_lens_tensor, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, - # Begin encoder & cross attn fields below... - encoder_seq_lens=self.encoder_seq_lens, - encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, - max_encoder_seq_len=self.max_encoder_seq_len, - cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables) - - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] - return self._cached_decode_metadata - - -def _get_attn_bias( - attn_metadata: XFormersMetadata, - attn_type: str, -) -> Optional[AttentionBias]: - ''' - Extract appropriate attention bias from attention metadata - according to attention type. 
- - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate attention bias value given the attention type - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - return attn_metadata.attn_bias - elif attn_type == AttentionType.ENCODER: - return attn_metadata.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return attn_metadata.cross_attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -def _set_attn_bias( - attn_metadata: XFormersMetadata, - attn_bias: List[Optional[AttentionBias]], - attn_type: str, -) -> None: - ''' - Update appropriate attention bias field of attention metadata, - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_bias: The desired attention bias value - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - attn_metadata.attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER: - attn_metadata.encoder_attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - attn_metadata.cross_attn_bias = attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): - - _metadata_cls = XFormersMetadata - - -class XFormersImpl(AttentionImpl[XFormersMetadata]): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. - - If chunked prefill is enabled, prefill tokens and decode tokens can be - batched together in a flattened 1D query. - - |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| - |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| - - Currently, cuda graph is disabled for chunked prefill, meaning there's no - padding between prefill and decode tokens. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0 " - "XFORMERS backend.") - if logits_soft_cap is not None: - logger.warning_once("XFormers does not support logits soft cap. 
" - "Outputs may be slightly off.") - if use_irope: - logger.warning_once( - "Using irope in XFormers is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: Optional[torch.Tensor], - value: Optional[torch.Tensor], - kv_cache: torch.Tensor, - attn_metadata: "XFormersMetadata", - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - output_block_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with xFormers and PagedAttention. - - For decoder-only models: query, key and value must be non-None. - - For encoder/decoder models: - * XFormersImpl.forward() may be invoked for both self- and cross- - attention layers. - * For self-attention: query, key and value must be non-None. - * For cross-attention: - * Query must be non-None - * During prefill, key and value must be non-None; key and value - get cached for use during decode. - * During decode, key and value may be None, since: - (1) key and value tensors were cached during prefill, and - (2) cross-attention key and value tensors do not grow during - decode - - A note on how the attn_type (attention type enum) argument impacts - attention forward() behavior: - - * DECODER: normal decoder-only behavior; - use decoder self-attention block table - * ENCODER: no KV caching; pass encoder sequence - attributes (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). - Used for encoder branch of encoder-decoder models. - * ENCODER_ONLY: no kv_caching, uses the normal attention - attributes (seq_lens/seq_lens_tensor/max_seq_len). - * ENCODER_DECODER: cross-attention behavior; - use cross-attention block table for caching KVs derived - from encoder hidden states; since KV sequence lengths - will match encoder sequence lengths, pass encoder sequence - attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ - max_encoder_seq_len) - - Args: - layer: Attention layer instance. - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: KV cache tensor with shape - [2, num_blocks, block_size * num_kv_heads * head_size]. - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - output: Optional output tensor. - output_scale: Optional output scale tensor. - output_block_scale: Optional output block scale tensor. 
- Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for XFormersImpl") - - attn_type = self.attn_type - # Check that appropriate attention metadata attributes are - # selected for the desired attention type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - # Self-attention vs. cross-attention will impact - # which KV cache memory-mapping & which - # seqlen datastructures we utilize - - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): - # KV-cache during decoder-self- or - # encoder-decoder-cross-attention, but not - # during encoder attention. - # - # Even if there are no new key/value pairs to cache, - # we still need to break out key_cache and value_cache - # i.e. for later use by paged attention - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if (key is not None) and (value is not None): - - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory - # profiling run. - PagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - - output = torch.empty_like(query) - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - if key is not None and value is not None: - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: - # normal attention. - # block tables are empty if the prompt does not have a cached - # prefix. 
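                # NOTE (added for clarity, not in the original file): two
                # prefill branches follow. This one runs xFormers
                # memory_efficient_attention over fresh prompt tokens; the
                # else-branch uses PagedAttention.forward_prefix when part of
                # the prompt is already cached (non-empty block tables).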
- out = self._run_memory_efficient_xformers_forward( - query, key, value, prefill_meta, attn_type=attn_type) - assert out.shape == output[:num_prefill_query_tokens].shape - output[:num_prefill_query_tokens] = out - else: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have prefix attention.") - - assert prefill_meta.query_start_loc is not None - assert prefill_meta.max_query_len is not None - - # prefix-enabled attention - # TODO(Hai) this triton kernel has regression issue (broke) to - # deal with different data types between KV and FP8 KV cache, - # to be addressed separately. - out = PagedAttention.forward_prefix( - query, - key, - value, - self.kv_cache_dtype, - key_cache, - value_cache, - prefill_meta.block_tables, - prefill_meta.query_start_loc, - prefill_meta.seq_lens_tensor, - prefill_meta.max_query_len, - self.alibi_slopes, - self.sliding_window, - layer._k_scale, - layer._v_scale, - ) - assert output[:num_prefill_query_tokens].shape == out.shape - output[:num_prefill_query_tokens] = out - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") - - ( - seq_lens_arg, - max_seq_len_arg, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - - output[num_prefill_query_tokens:] = PagedAttention.forward_decode( - decode_query, - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Reshape the output tensor. - return output.view(-1, self.num_heads * self.head_size) - - def _run_memory_efficient_xformers_forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: XFormersMetadata, - attn_type: str = AttentionType.DECODER, - ) -> torch.Tensor: - """Attention for 1D query of multiple prompts. Multiple prompt - tokens are flattened in to `query` input. - - See https://facebookresearch.github.io/xformers/components/ops.html - for API spec. - - Args: - query: shape = [num_prefill_tokens, num_heads, head_size] - key: shape = [num_prefill_tokens, num_kv_heads, head_size] - value: shape = [num_prefill_tokens, num_kv_heads, head_size] - attn_metadata: Metadata for attention. - attn_type: Select attention type, between encoder attention, - decoder self-attention, or encoder/decoder cross- - attention. Defaults to decoder self-attention, - which is the vLLM default generally - """ - - original_query = query - if self.num_kv_heads != self.num_heads: - # GQA/MQA requires the shape [B, M, G, H, K]. - # Note that the output also has the same shape (which is different - # from a spec from the doc). - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # Set attention bias if not provided. This typically happens at - # the very attention layer of every iteration. - # FIXME(woosuk): This is a hack. 
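As an aside, the GQA/MQA head expansion shown above relies on a zero-copy broadcast; a standalone PyTorch sketch of the same pattern, with made-up sizes:

```python
# Standalone sketch of the MQA/GQA key/value broadcast, with made-up sizes.
# expand() only adjusts strides, so no data is copied.
import torch

num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 4
queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         queries_per_kv, head_size)

assert key_expanded.shape == (num_tokens, num_kv_heads, queries_per_kv,
                              head_size)
# Every query head in a group reads the same KV head.
assert torch.equal(key_expanded[:, 0, 0], key_expanded[:, 0, 1])
```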
- attn_bias = _get_attn_bias(attn_metadata, attn_type) - if attn_bias is None: - if self.alibi_slopes is None: - - # Cross attention block of decoder branch of encoder-decoder - # model uses seq_lens for dec / encoder_seq_lens for enc - if (attn_type == AttentionType.ENCODER_DECODER): - assert attn_metadata.seq_lens is not None - assert attn_metadata.encoder_seq_lens is not None - - # Cross-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, - attn_metadata.encoder_seq_lens, - device=query.device) - - # Encoder branch of encoder-decoder model uses - # attn_metadata.encoder_seq_lens - elif attn_type == AttentionType.ENCODER: - - assert attn_metadata.encoder_seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.encoder_seq_lens, device=query.device) - - # Self-attention block of encoder-only model just - # uses the seq_lens directly. - elif attn_type == AttentionType.ENCODER_ONLY: - assert attn_metadata.seq_lens is not None - - # Encoder self-attention mask is non-causal - attn_bias = BlockDiagonalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - - # Self-attention block of decoder branch just - # uses the seq_lens directly - elif attn_type == AttentionType.DECODER: - assert attn_metadata.seq_lens is not None - - # Decoder self-attention mask is causal - attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.seq_lens, device=query.device) - else: - raise ValueError("Unknown AttentionType: %s", attn_type) - - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention( - self.sliding_window) - attn_bias = [attn_bias] - else: - assert attn_type == AttentionType.DECODER - assert attn_metadata.seq_lens is not None - attn_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, query.dtype, - attn_metadata.seq_lens) - - _set_attn_bias(attn_metadata, attn_bias, attn_type) - - # No alibi slopes. - # TODO(woosuk): Too many view operations. Let's try to reduce - # them in the future for code readability. - if self.alibi_slopes is None: - # Add the batch dimension. - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - out = xops.memory_efficient_attention_forward( - query, - key, - value, - attn_bias=attn_bias[0], - p=0.0, - scale=self.scale) - return out.view_as(original_query) - - # Attention with alibi slopes. - # FIXME(woosuk): Because xformers does not support dynamic sequence - # lengths with custom attention bias, we process each prompt one by - # one. This is inefficient, especially when we have many short prompts. - assert attn_metadata.seq_lens is not None - output = torch.empty_like(original_query) - start = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens): - end = start + seq_len - out = xops.memory_efficient_attention_forward( - query[None, start:end], - key[None, start:end], - value[None, start:end], - attn_bias=attn_bias[i], - p=0.0, - scale=self.scale) - # TODO(woosuk): Unnecessary copy. Optimize. - output[start:end].copy_(out.view_as(original_query[start:end])) - start += seq_len - return output - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[AttentionBias]: - attn_biases: List[AttentionBias] = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. 
We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - # Calculate a matrix where each element represents ith element- jth - # element. - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) - - return attn_biases diff --git a/vllm/config/model.py b/vllm/config/model.py index 95fe52883db0..33e5d3ea04a4 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -32,8 +32,7 @@ from vllm.transformers_utils.runai_utils import (ObjectStorageModel, is_runai_obj_uri) from vllm.transformers_utils.utils import maybe_model_redirect -from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, LayerBlockType, - LazyLoader, common_broadcastable_dtype) +from vllm.utils import LayerBlockType, LazyLoader, common_broadcastable_dtype if TYPE_CHECKING: from transformers import PretrainedConfig @@ -1103,10 +1102,6 @@ def verify_dual_chunk_attention_config( self.hf_config.dual_chunk_attention_config[ "sparse_attention_enabled"] = True - if envs.VLLM_ATTENTION_BACKEND != STR_DUAL_CHUNK_FLASH_ATTN_VAL: - raise ValueError("please set VLLM_ATTENTION_BACKEND to " - f"{STR_DUAL_CHUNK_FLASH_ATTN_VAL}") - def verify_with_parallel_config( self, parallel_config: ParallelConfig, diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 911d77ba36fa..efa4c9abf47f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -44,7 +44,7 @@ def get_model_args(self, model_executable: torch.nn.Module): # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading # to a kv_cache shape of [2, num_blks, blk_size, # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/attention/backends/mla/common.py. + # For more details, see vllm/v1/attention/backends/mla/common.py. if self.is_deepseek_mla and self.use_mla_opt: head_size = model_config.kv_lora_rank + \ model_config.qk_rope_head_dim diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e00260caa39..b09d43f70558 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,8 +44,8 @@ from vllm.transformers_utils.config import (get_model_path, is_interleaved, maybe_override_with_speculators) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, - GiB_bytes, get_ip, is_in_ray_actor) +from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip, + is_in_ray_actor) from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -1163,17 +1163,6 @@ def create_engine_config( self._set_default_args_v0(model_config) assert self.enable_chunked_prefill is not None - if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: - assert self.enforce_eager, ( - "Cuda graph is not supported with DualChunkFlashAttention. 
" - "To run the model in eager mode, set 'enforce_eager=True' " - "or use '--enforce-eager' in the CLI.") - assert current_platform.is_cuda(), ( - "DualChunkFlashAttention is only supported on CUDA platform.") - assert not use_v1, ( - "DualChunkFlashAttention is not supported on V1 engine. " - "To run the model in V0 engine, try set 'VLLM_USE_V1=0'") - sliding_window: Optional[int] = None if not is_interleaved(model_config.hf_text_config): # Only set CacheConfig.sliding_window if the model is all sliding diff --git a/vllm/envs.py b/vllm/envs.py index 3991a789d80f..cbd1d5474e60 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -529,7 +529,6 @@ def get_vllm_port() -> Optional[int]: # - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "FLASH_ATTN": use FlashAttention # - "XFORMERS": use XFormers - # - "ROCM_FLASH": use ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index c926e17a2c19..7f376b70a7ae 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -53,13 +53,18 @@ class Mamba2Metadata: def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: """Returns the appropriate metadata classes for the current platform.""" if current_platform.is_rocm(): - from vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata) - return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata) - elif current_platform.is_cuda(): - from vllm.attention.backends.flash_attn import FlashAttentionMetadata - from vllm.attention.backends.xformers import XFormersMetadata - return (FlashAttentionMetadata, XFormersMetadata, + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + from vllm.v1.attention.backends.triton_attn import ( + TritonAttentionMetadata) + return (AiterFlashAttentionMetadata, TritonAttentionMetadata, + PlaceholderAttentionMetadata) + if current_platform.is_cuda(): + from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata) + from vllm.v1.attention.backends.xformers import ( + XFormersAttentionMetadata) + return (FlashAttentionMetadata, XFormersAttentionMetadata, PlaceholderAttentionMetadata) raise ValueError( f"Unsupported platform for Mamba2: {current_platform.device_type}") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index a99a6679a569..415d36c681d8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -478,7 +478,8 @@ class DeepseekV2MLAAttention(nn.Module): Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
- For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py """ def __init__( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c263e2afe83b..05f129f513a0 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -226,8 +226,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: - # TODO(lucas): refactor to be more concise - # we should probably consider factoring out V1 here + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them.") from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla @@ -246,35 +248,17 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, use_triton = selected_backend == _Backend.TRITON_MLA or ( selected_backend is None) - def _get_version(name, import_suffix) -> str: - if use_v1: - logger.info_once(f"Using {name} backend on V1 engine.") - return f"vllm.v1.attention.backends.mla.{import_suffix}" - else: - logger.info_once(f"Using {name} backend.") - return f"vllm.attention.backends.{import_suffix}" - if use_cutlassmla: - if use_v1: - logger.info_once("Using Cutlass MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "cutlass_mla.CutlassMLABackend") - else: - logger.warning( - "Cutlass MLA backend is only supported on V1 engine") + logger.info_once("Using Cutlass MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "cutlass_mla.CutlassMLABackend") if use_flashinfermla: - if use_v1: - from vllm.v1.attention.backends.utils import ( - set_kv_cache_layout) - set_kv_cache_layout("HND") - logger.info_once( - "Using FlashInfer MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashinfer_mla.FlashInferMLABackend") - else: - logger.warning( - "FlashInfer MLA backend is only supported on V1 engine" - ) + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") + logger.info_once("Using FlashInfer MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashinfer_mla.FlashInferMLABackend") if use_flashmla: if block_size != 64: logger.warning( @@ -282,20 +266,18 @@ def _get_version(name, import_suffix) -> str: " (currently only supports block size 64).", block_size) else: - return _get_version("FlashMLA", "flashmla.FlashMLABackend") - if use_flashattn: - if use_v1: - logger.info_once( - "Using FlashAttention MLA backend on V1 engine.") + logger.info_once("Using FlashMLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." - "flashattn_mla.FlashAttnMLABackend") - else: - logger.warning( - "FlashAttention MLA backend is only supported on V1 " - "engine.") + "flashmla.FlashMLABackend") + if use_flashattn: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "flashattn_mla.FlashAttnMLABackend") if use_triton: - return _get_version("Triton MLA", - "triton_mla.TritonMLABackend") + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 @@ -382,78 +364,9 @@ def _get_version(name, import_suffix) -> str: ) return FLEX_ATTENTION_V1 - # Backends for V0 engine - if selected_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN: - logger.info("Using DualChunkFlashAttention backend.") - return ("vllm.attention.backends.dual_chunk_flash_attn." - "DualChunkFlashAttentionBackend") - elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: - logger.info("Using DifferentialFlashAttention backend.") - return ("vllm.attention.backends.differential_flash_attn." - "DifferentialFlashAttentionBackend") - elif selected_backend == _Backend.FLASH_ATTN: - pass - elif selected_backend: - raise ValueError( - f"Invalid attention backend for {cls.device_name}, " - f"with use_v1: {use_v1} use_mla: {use_mla}") - - target_backend = _Backend.FLASH_ATTN - if not cls.has_device_capability(80): - # Volta and Turing NVIDIA GPUs. - logger.info( - "Cannot use FlashAttention-2 backend for Volta and Turing " - "GPUs.") - target_backend = _Backend.XFORMERS - elif dtype not in (torch.float16, torch.bfloat16): - logger.info( - "Cannot use FlashAttention-2 backend for dtype other than " - "torch.float16 or torch.bfloat16.") - target_backend = _Backend.XFORMERS - elif block_size % 16 != 0: - logger.info( - "Cannot use FlashAttention-2 backend for block size not " - "divisible by 16.") - target_backend = _Backend.XFORMERS - - # FlashAttn is valid for the model, checking if the package is - # installed. - if target_backend == _Backend.FLASH_ATTN: - try: - import vllm.vllm_flash_attn # noqa: F401 - from vllm.attention.backends.flash_attn import ( # noqa: F401 - FlashAttentionBackend, flash_attn_supports_fp8) - - supported_sizes = \ - FlashAttentionBackend.get_supported_head_sizes() - if head_size not in supported_sizes: - logger.info( - "Cannot use FlashAttention-2 backend for head size %d.", - head_size) - target_backend = _Backend.XFORMERS - fp8_kv_cache = (kv_cache_dtype is not None - and kv_cache_dtype.startswith("fp8")) - if (fp8_kv_cache and not flash_attn_supports_fp8()): - logger.info( - "Cannot use FlashAttention backend for FP8 KV cache.") - target_backend = _Backend.XFORMERS - except ImportError: - logger.info( - "Cannot use FlashAttention-2 backend because the " - "vllm.vllm_flash_attn package is not found. " - "Make sure that vllm_flash_attn was built and installed " - "(on by default).") - target_backend = _Backend.XFORMERS - - if target_backend == _Backend.XFORMERS: - logger.info("Using XFormers backend.") - return "vllm.attention.backends.xformers.XFormersBackend" - - logger.info("Using Flash Attention backend.") - return "vllm.attention.backends.flash_attn.FlashAttentionBackend" + raise RuntimeError( + "V0 attention backends have been removed. 
Set VLLM_USE_V1=1 " + "to select a supported backend.") @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index dce2924ac7a9..9470434aa428 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -191,6 +191,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink) -> str: if use_mla: + if not use_v1: + raise RuntimeError( + "MLA attention backends require the V1 engine. " + "Set VLLM_USE_V1=1 to enable them.") + from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) @@ -201,39 +206,24 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend == _Backend.TRITON_MLA: if block_size != 1: - if use_v1: - logger.info_once( - "Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA \ - or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: + logger.info_once("Using Triton MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "triton_mla.TritonMLABackend") + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"does not support block size {block_size}.") + if selected_backend in (_Backend.ROCM_AITER_MLA, + _Backend.ROCM_AITER_MLA_VLLM_V1): if block_size == 1: - if use_v1: - logger.info("Using AITER MLA backend on V1 engine.") - return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - logger.info("Using AITER MLA backend") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 - else: - raise ValueError( - f" The selected backend, {selected_backend.name}," - f"does not support block size {block_size}." - "(currently only supports block size 1)") - else: + logger.info("Using AITER MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 raise ValueError( f" The selected backend, {selected_backend.name}," - f"is not MLA type while requested for MLA backend.") - - if selected_backend is None or selected_backend == _Backend.FLASH_ATTN: - selected_backend = _Backend.ROCM_FLASH + f"does not support block size {block_size}." + "(currently only supports block size 1)") + raise ValueError( + f" The selected backend, {selected_backend.name}," + f"is not MLA type while requested for MLA backend.") if envs.VLLM_USE_V1: if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ @@ -245,14 +235,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Triton Attention backend on V1 engine.") return ("vllm.v1.attention.backends." "triton_attn.TritonAttentionBackend") - if selected_backend == _Backend.ROCM_FLASH: - if not cls.has_device_capability(90): - # not Instinct series GPUs. - logger.info("flash_attn is not supported on NAVI GPUs.") - else: - logger.info("%s is not supported in AMD GPUs.", selected_backend) - logger.info("Using ROCmFlashAttention backend.") - return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + raise RuntimeError( + "V0 attention backends have been removed. 
Set VLLM_USE_V1=1 " + "to select a supported backend.") @classmethod def set_device(cls, device: torch.device) -> None: diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 968bba664f0a..834ec9b1d30b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -157,10 +157,8 @@ # register, corresponding to possible backends STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" -STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" -STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" MB_bytes = 1_000_000 From a271abf9d9ebed4e4f08e47899119395a4de6da2 Mon Sep 17 00:00:00 2001 From: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Date: Sun, 21 Sep 2025 16:06:16 -0700 Subject: [PATCH 11/17] [Bugfix][V0 Deprecation][CI] use async mock and await for async method (#25325) Signed-off-by: Yang --- .../entrypoints/openai/test_lora_resolvers.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index e2c83b9c4004..9d5ee84a1956 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from http import HTTPStatus from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -83,20 +83,31 @@ def register_mock_resolver(): def mock_serving_setup(): """Provides a mocked engine and serving completion instance.""" mock_engine = MagicMock(spec=AsyncLLM) - mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False - def mock_add_lora_side_effect(lora_request: LoRARequest): + tokenizer = get_tokenizer(MODEL_NAME) + mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer) + + async def mock_add_lora_side_effect(lora_request: LoRARequest): """Simulate engine behavior when adding LoRAs.""" if lora_request.lora_name == "test-lora": # Simulate successful addition - return - elif lora_request.lora_name == "invalid-lora": + return True + if lora_request.lora_name == "invalid-lora": # Simulate failure during addition (e.g. 
invalid format) raise ValueError(f"Simulated failure adding LoRA: " f"{lora_request.lora_name}") + return True + + mock_engine.add_lora = AsyncMock(side_effect=mock_add_lora_side_effect) + + async def mock_generate(*args, **kwargs): + for _ in []: + yield _ + + mock_engine.generate = MagicMock(spec=AsyncLLM.generate, + side_effect=mock_generate) - mock_engine.add_lora.side_effect = mock_add_lora_side_effect mock_engine.generate.reset_mock() mock_engine.add_lora.reset_mock() @@ -131,7 +142,7 @@ async def test_serving_completion_with_lora_resolver(mock_serving_setup, with suppress(Exception): await serving_completion.create_completion(req_found) - mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == lora_model_name @@ -157,7 +168,7 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, response = await serving_completion.create_completion(req) - mock_engine.add_lora.assert_not_called() + mock_engine.add_lora.assert_not_awaited() mock_engine.generate.assert_not_called() assert isinstance(response, ErrorResponse) @@ -181,7 +192,7 @@ async def test_serving_completion_resolver_add_lora_fails( response = await serving_completion.create_completion(req) # Assert add_lora was called before the failure - mock_engine.add_lora.assert_called_once() + mock_engine.add_lora.assert_awaited_once() called_lora_request = mock_engine.add_lora.call_args[0][0] assert isinstance(called_lora_request, LoRARequest) assert called_lora_request.lora_name == invalid_model From 73f2bef0cae40efa7f8cf58dc00a16ee36a1441c Mon Sep 17 00:00:00 2001 From: Deboleina Date: Sun, 21 Sep 2025 19:07:11 -0400 Subject: [PATCH 12/17] Multimodal - audio tests (#25285) Signed-off-by: Debolina Roy --- tests/multimodal/test_audio.py | 140 +++++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 tests/multimodal/test_audio.py diff --git a/tests/multimodal/test_audio.py b/tests/multimodal/test_audio.py new file mode 100644 index 000000000000..ba39af845041 --- /dev/null +++ b/tests/multimodal/test_audio.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# test_audio.py +import base64 +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest + +from vllm.multimodal.audio import (AudioMediaIO, AudioResampler, + resample_audio_librosa, + resample_audio_scipy) + + +@pytest.fixture +def dummy_audio(): + return np.array([0.0, 0.1, 0.2, 0.3, 0.4], dtype=float) + + +def test_resample_audio_librosa(dummy_audio): + with patch("vllm.multimodal.audio.librosa.resample") as mock_resample: + mock_resample.return_value = dummy_audio * 2 + out = resample_audio_librosa(dummy_audio, + orig_sr=44100, + target_sr=22050) + mock_resample.assert_called_once_with(dummy_audio, + orig_sr=44100, + target_sr=22050) + assert np.all(out == dummy_audio * 2) + + +def test_resample_audio_scipy(dummy_audio): + out_down = resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=2) + out_up = resample_audio_scipy(dummy_audio, orig_sr=2, target_sr=4) + out_same = resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=4) + + assert len(out_down) == 3 + assert len(out_up) == 10 + assert np.all(out_same == dummy_audio) + + +@pytest.mark.xfail( + reason="resample_audio_scipy is buggy for non-integer ratios") +def 
test_resample_audio_scipy_non_integer_ratio(dummy_audio): + out = resample_audio_scipy(dummy_audio, orig_sr=5, target_sr=3) + + expected_len = int(round(len(dummy_audio) * 3 / 5)) + assert len(out) == expected_len + + assert isinstance(out, np.ndarray) + assert np.isfinite(out).all() + + +def test_audio_resampler_librosa_calls_resample(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="librosa") + with patch( + "vllm.multimodal.audio.resample_audio_librosa") as mock_resample: + mock_resample.return_value = dummy_audio + out = resampler.resample(dummy_audio, orig_sr=44100) + mock_resample.assert_called_once_with(dummy_audio, + orig_sr=44100, + target_sr=22050) + assert np.all(out == dummy_audio) + + +def test_audio_resampler_scipy_calls_resample(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="scipy") + with patch("vllm.multimodal.audio.resample_audio_scipy") as mock_resample: + mock_resample.return_value = dummy_audio + out = resampler.resample(dummy_audio, orig_sr=44100) + mock_resample.assert_called_once_with(dummy_audio, + orig_sr=44100, + target_sr=22050) + assert np.all(out == dummy_audio) + + +def test_audio_resampler_invalid_method(dummy_audio): + resampler = AudioResampler(target_sr=22050, method="invalid") + with pytest.raises(ValueError): + resampler.resample(dummy_audio, orig_sr=44100) + + +def test_audio_resampler_no_target_sr(dummy_audio): + resampler = AudioResampler(target_sr=None) + with pytest.raises(RuntimeError): + resampler.resample(dummy_audio, orig_sr=44100) + + +@pytest.fixture +def dummy_audio_bytes(): + return b"FAKEAUDIOBYTES" + + +def test_audio_media_io_load_bytes(dummy_audio_bytes): + audio_io = AudioMediaIO() + with patch("vllm.multimodal.audio.librosa.load") as mock_load: + mock_load.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_bytes(dummy_audio_bytes) + mock_load.assert_called_once() + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_load_base64(dummy_audio_bytes): + audio_io = AudioMediaIO() + encoded = base64.b64encode(dummy_audio_bytes).decode("utf-8") + with patch.object(AudioMediaIO, "load_bytes") as mock_load_bytes: + mock_load_bytes.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_base64("audio/wav", encoded) + mock_load_bytes.assert_called_once() + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_load_file(): + audio_io = AudioMediaIO() + path = Path("/fake/path.wav") + with patch("vllm.multimodal.audio.librosa.load") as mock_load: + mock_load.return_value = (np.array([0.1, 0.2]), 16000) + out = audio_io.load_file(path) + mock_load.assert_called_once_with(path, sr=None) + assert isinstance(out[0], np.ndarray) + assert out[1] == 16000 + + +def test_audio_media_io_encode_base64(dummy_audio): + audio_io = AudioMediaIO() + media = (dummy_audio, 16000) + with patch("vllm.multimodal.audio.soundfile.write") as mock_write: + + def write_to_buffer(buffer, *_args, **_kwargs): + buffer.write(b"dummy_wav_data") + + mock_write.side_effect = write_to_buffer + + out = audio_io.encode_base64(media) + decoded = base64.b64decode(out) + assert decoded == b"dummy_wav_data" + mock_write.assert_called_once() From 1ffb41287d754ae1456f83223be73cd497f41a20 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 21 Sep 2025 19:24:40 -0700 Subject: [PATCH 13/17] [Model] Support Dots OCR (#24645) Signed-off-by: Roger Wang Co-authored-by: yinz-aizip --- docs/models/supported_models.md | 1 + 
examples/offline_inference/vision_language.py | 18 + tests/models/registry.py | 2 + vllm/model_executor/models/dots_ocr.py | 824 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/dotsocr.py | 69 ++ 7 files changed, 917 insertions(+) create mode 100644 vllm/model_executor/models/dots_ocr.py create mode 100644 vllm/transformers_utils/configs/dotsocr.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index cbc0a56a645e..9d288667a318 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -352,6 +352,7 @@ th { | `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | | `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | +| `DotsOCRForCausalLM` | dots_ocr | `rednote-hilab/dots.ocr` | | ✅︎ | ✅︎ | | `Ernie4_5ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | | `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index de3f3afc1794..f8ddb5a22b31 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -126,6 +126,23 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ) +# Dots-OCR +def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] + engine_args = EngineArgs( + model="rednote-hilab/dots.ocr", + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1676,6 +1693,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "aya_vision": run_aya_vision, "blip-2": run_blip2, "chameleon": run_chameleon, + "dots_ocr": run_dots_ocr, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, "ernie45_vl": run_ernie45_vl, diff --git a/tests/models/registry.py b/tests/models/registry.py index e9cc5170ade7..29b6980aaa42 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -448,6 +448,8 @@ def check_available_online( max_transformers_version="4.48", # noqa: E501 transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 + "DotsOCRForCausalLM": _HfExamplesInfo("rednote-hilab/dots.ocr", + trust_remote_code=True), "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 trust_remote_code=True), diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py new file mode 100644 index 000000000000..04fa5584199a --- /dev/null +++ b/vllm/model_executor/models/dots_ocr.py @@ -0,0 +1,824 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm +from transformers.modeling_utils import PreTrainedModel +from transformers.models.qwen2_vl import Qwen2VLProcessor + +from vllm.attention.layer import check_upstream_fa_availability +from vllm.config import VllmConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsPP) +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder, + Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo) +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, + maybe_prefix, + merge_multimodal_embeddings) +from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, + DotsVisionConfig) + +IMAGE_TOKEN = "<|imgpad|>" + + +class DotsOCRImagePixelInputs(TypedDict): + type: Literal["pixel_values", "image_grid_thw"] + + pixel_values: torch.Tensor + image_grid_thw: torch.Tensor + + +class DotsOCRImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds", "image_grid_thw"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. 
+ """ + + image_grid_thw: torch.Tensor + + +DotsOCRImageInputs = Union[DotsOCRImagePixelInputs, + DotsOCRImageEmbeddingInputs] + + +class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + return IMAGE_TOKEN * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501 + ) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class DotsOCRProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self) -> DotsOCRConfig: + config = self.ctx.get_hf_config() + if not config.__class__.__name__ == 'DotsOCRConfig': + raise TypeError(f"Expected DotsOCRConfig, got {type(config)}") + + if hasattr(config, "vision_config") and isinstance( + config.vision_config, dict): + config.vision_config = DotsVisionConfig(**config.vision_config) + + return config + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + max_image_tokens = self.get_max_image_tokens() + return {"image": max_image_tokens} + + def get_hf_processor( + self, + **kwargs: object, + ) -> Qwen2VLProcessor: + self.get_tokenizer( + ).image_token = IMAGE_TOKEN # Ensure image token is set + processor = self.ctx.get_hf_processor( + Qwen2VLProcessor, + **kwargs, + ) + processor.image_token = IMAGE_TOKEN + processor.video_token = "<|video_pad|>" + return processor + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(tensor: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + orig_dtype = tensor.dtype + tensor = tensor.float() + + cos = freqs.cos() + sin = freqs.sin() + + cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float() + + output = (tensor * cos) + (rotate_half(tensor) * sin) + + output = output.to(orig_dtype) + + return output + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class PatchMerger(nn.Module): + + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + pre_norm="layernorm", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.pre_norm = pre_norm + if self.pre_norm == "layernorm": + self.ln_q = LayerNorm(context_dim, eps=1e-6) + elif self.pre_norm == "rmsnorm": + self.ln_q = RMSNorm(context_dim, eps=1e-6) + else: + print("no norm in patch merger") + + self.mlp = nn.Sequential( + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + return_bias=False, + disable_tp=True), + nn.GELU(), + RowParallelLinear(self.hidden_size, + dim, + bias=True, + 
return_bias=False, + disable_tp=True), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.pre_norm: + x = self.mlp(self.ln_q(x).view(-1, self.hidden_size)) + else: + x = self.mlp(x.view(-1, self.hidden_size)) + return x + + +class DotsVisionAttention(nn.Module): + + def __init__(self, + config, + dim: int, + num_heads: int = 16, + bias: bool = True, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + from vllm.distributed import (parallel_state, + tensor_model_parallel_all_gather) + from vllm.distributed import utils as dist_utils + + self.embed_dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.num_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + # qkv/proj follow Qwen2-VL style; bias controlled by arg + self.qkv = QKVParallelLinear(hidden_size=dim, + head_size=dim // num_heads, + total_num_heads=num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=dim, + output_size=dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.proj") + self._all_gather = tensor_model_parallel_all_gather + self._split_last = dist_utils.split_tensor_along_last_dim + + # Select attention backend + self.attn_backend = get_vit_attn_backend(self.head_dim, + torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Unsupported vision attention backend: {self.attn_backend}") + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def _split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # qkv: [S, B, 3*dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = self._all_gather(qkv) + q, k, v = qkv.chunk(3, dim=2) + if self.tp_size > 1: + q = self._split_last(q, num_partitions=self.tp_size)[self.tp_rank] + k = self._split_last(k, num_partitions=self.tp_size)[self.tp_rank] + v = self._split_last(v, num_partitions=self.tp_size)[self.tp_rank] + new_shape = (seq_len, bs, self.num_heads_per_partition, self.head_dim) + return (q.view(*new_shape), k.view(*new_shape), v.view(*new_shape)) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + *, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None, + ) -> torch.Tensor: + # [S, C] -> [S, B=1, C] + x = hidden_states.unsqueeze(1) + x, _ = self.qkv(x) + q, k, v = self._split_qkv(x) + bs = q.shape[1] + # [S,B,H,D] -> [B,S,H,D] + q = q.permute(1, 0, 2, 3).contiguous() + k = k.permute(1, 0, 2, 3).contiguous() + v = v.permute(1, 0, 2, 3).contiguous() + + if rotary_pos_emb is not None: + qk_concat = torch.cat([q, k], dim=0) + qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) + q, k = torch.chunk(qk_rotated, 2, dim=0) + + if self.is_flash_attn_backend: + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: 
+ from vllm.vllm_flash_attn import flash_attn_varlen_func + q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) + k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) + v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) + output = flash_attn_varlen_func(q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) + context_layer = output.view(bs, -1, self.num_heads_per_partition, + self.head_dim) + elif self.attn_backend == _Backend.TORCH_SDPA: + outputs = [] + for i in range(1, len(cu_seqlens)): + s = int(cu_seqlens[i - 1]) + e = int(cu_seqlens[i]) + q_i = q[:, s:e].permute(0, 2, 1, 3) + k_i = k[:, s:e].permute(0, 2, 1, 3) + v_i = v[:, s:e].permute(0, 2, 1, 3) + out_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + out_i = out_i.permute(0, 2, 1, 3) + outputs.append(out_i) + context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + else: + raise RuntimeError("Unsupported attention backend") + + # [B,S,H,D] -> [S,B,H*D] -> [S, C] + context_layer = context_layer.permute(1, 0, 2, 3).contiguous() + context_layer = context_layer.view(context_layer.shape[0], bs, -1) + out, _ = self.proj(context_layer) + return out.squeeze(1) + + +class DotsSwiGLUFFN(nn.Module): + + def __init__(self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + hidden_features = config.intermediate_size + in_features = config.embed_dim + bias = config.use_bias + + # Referenced aimv2.py AIMv2SwiGLUFFN + self.fc13 = MergedColumnParallelLinear(in_features, + [hidden_features] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc13", + disable_tp=True) + self.fc2 = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + disable_tp=True) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc13(x) + x = self.act_fn(x) + x, _ = self.fc2(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params = dict(self.named_parameters()) + loaded: set[str] = set() + for name, w in weights: + # Map fc1 -> fc13 (shard 0) + if name.startswith("fc1."): + tgt = name.replace("fc1.", "fc13.") + if tgt in params: + params[tgt].weight_loader(params[tgt], w, 0) + loaded.add(tgt) + continue + # Map fc3 -> fc13 (shard 1) + if name.startswith("fc3."): + tgt = name.replace("fc3.", "fc13.") + if tgt in params: + params[tgt].weight_loader(params[tgt], w, 1) + loaded.add(tgt) + continue + # Pass-through for fc2 and others + if name in params: + params[name].weight_loader(params[name], w) + loaded.add(name) + return loaded + + +class DotsPatchEmbed(nn.Module): + + def __init__(self, config): + super().__init__() + self.num_channels = config.num_channels + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.embed_dim = config.embed_dim + self.config = config + self.proj = nn.Conv2d( + config.num_channels, + config.embed_dim, + kernel_size=(config.patch_size, config.patch_size), + 
stride=(config.patch_size, config.patch_size), + ) + self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + x = x.view(-1, self.num_channels, self.temporal_patch_size, + self.patch_size, self.patch_size)[:, :, 0] + x = self.proj(x).view(-1, self.embed_dim) + x = self.norm(x) + return x + + +class DotsViTPreprocessor(nn.Module): + + def __init__(self, config): + super().__init__() + self.patch_h = config.patch_size + self.patch_w = config.patch_size + self.embed_dim = config.embed_dim + self.config = config + self.patchifier = DotsPatchEmbed(config) + + def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor: + tokens = self.patchifier(x, grid_thw) + return tokens + + +class DotsVisionBlock(nn.Module): + + def __init__(self, + config, + *, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.attn = DotsVisionAttention( + config, + config.embed_dim, + num_heads=config.num_attention_heads, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + self.mlp = DotsSwiGLUFFN(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps) + + def forward(self, + hidden_states: torch.Tensor, + *, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, + seqlens: Optional[list[int]] = None) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class DotsVisionTransformer(PreTrainedModel): + + def __init__( + self, + config: DotsVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + require_post_norm: Optional[bool] = None, + prefix: str = "", + ) -> None: + super().__init__(config) + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = DotsViTPreprocessor(config) + + head_dim = config.embed_dim // config.num_attention_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + self.attn_backend = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + + # Keep blocks for compatibility with other vision towers + num_layers = (config.num_hidden_layers if num_hidden_layers_override + is None else num_hidden_layers_override) + self.blocks = nn.ModuleList([ + DotsVisionBlock(config, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}") + for i in range(num_layers) + ]) + if require_post_norm is None: + require_post_norm = (len(self.blocks) == config.num_hidden_layers) + if require_post_norm and self.config.post_norm: + self.post_trunk_norm = RMSNorm(config.embed_dim, + eps=config.rms_norm_eps) + else: + self.post_trunk_norm = None + + self.merger = PatchMerger( + dim=config.hidden_size, + context_dim=config.embed_dim, + spatial_merge_size=config.spatial_merge_size, + ) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.patchifier.proj.weight.dtype + + @property + def device(self) -> torch.device: + return 
self.patch_embed.patchifier.proj.weight.device + + def get_pos_ids_by_grid(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + return pos_ids + + def rot_pos_emb(self, grid_thw): + pos_ids = self.get_pos_ids_by_grid(grid_thw) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, hidden_states: torch.Tensor, + grid_thw: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.to(self.dtype) + hidden_states = self.patch_embed(hidden_states, grid_thw) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, + dtype=grid_thw.dtype + if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + for blk in self.blocks: + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens) + + if self.post_trunk_norm is not None: + hidden_states = self.post_trunk_norm(hidden_states) + + hidden_states = self.merger(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2VLMultiModalProcessor, + info=DotsOCRProcessingInfo, + dummy_inputs=DotsOCRDummyInputsBuilder, +) +class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".attn.qkv_proj.": ".attn.qkv.", + ".attn.out_proj.": ".attn.proj.", + }, + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }, + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|img|><|imgpad|><|endofimg|>" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.config: DotsOCRConfig = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.multimodal_config = vllm_config.model_config.multimodal_config + + if isinstance(self.config.vision_config, dict): + vision_config = DotsVisionConfig(**self.config.vision_config) + self.config.vision_config = vision_config + else: + vision_config = self.config.vision_config + + self.vision_tower = DotsVisionTransformer( + 
vision_config, + quant_config=self.quant_config, + prefix=maybe_prefix(prefix, "vision_tower"), + ) + self.language_model: Qwen2ForCausalLM = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=self.config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[DotsOCRImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return DotsOCRImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return DotsOCRImageEmbeddingInputs(type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _process_image_input( + self, image_input: DotsOCRImageInputs) -> tuple[torch.Tensor, ...]: + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type( + self.vision_tower.dtype) + else: + pixel_values = image_input["pixel_values"].type( + self.vision_tower.dtype) + image_embeds = self.vision_tower( + pixel_values, grid_thw)[:, :self.config.hidden_size] + + # Split concatenated embeddings for each image item. 
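As an aside, the per-image split sizes computed right below follow from simple grid arithmetic: each image contributes t*h*w patches, and spatial merging reduces the token count by merge_size**2. A standalone sketch with illustrative grid values:

```python
# Sketch of the split-size arithmetic: tokens per image = t*h*w / merge_size^2.
# Grid values and hidden size below are illustrative, not from a real model.
import torch

grid_thw_list = [[1, 28, 28], [1, 14, 56]]  # (t, h, w) per image
merge_size = 2

sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) //
         (merge_size * merge_size)).tolist()
assert sizes == [196, 196]

# A flat [sum(sizes), hidden] embedding tensor can then be split per image.
hidden = 8
flat = torch.zeros(sum(sizes), hidden)
per_image = flat.split(sizes)
assert [t.shape[0] for t in per_image] == sizes
```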
+ merge_size = self.vision_tower.spatial_merge_size + sizes = (torch.tensor(grid_thw_list, dtype=torch.long).prod(-1) // + (merge_size * merge_size)).tolist() + + return image_embeds.split(sizes) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_id, + ) + + return inputs_embeds + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + elif inputs_embeds is None and kwargs.get("pixel_values") is not None: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + inputs_embeds = None + else: + assert input_ids is not None + inputs_embeds = self.get_multimodal_embeddings( + input_ids, + image_input=image_input, + ) + input_ids = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dc5d545bb9c..86123bc092b9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -219,6 +219,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "DotsOCRForCausalLM": ("dots_ocr", "DotsOCRForCausalLM"), "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 91bfeb8c55ee..52fa49ad302b 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -9,6 +9,7 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config +from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig from vllm.transformers_utils.configs.eagle import 
EAGLEConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the @@ -36,6 +37,7 @@ __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "DotsOCRConfig", "EAGLEConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/dotsocr.py b/vllm/transformers_utils/configs/dotsocr.py new file mode 100644 index 000000000000..6bb3c12d9c7e --- /dev/null +++ b/vllm/transformers_utils/configs/dotsocr.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.qwen2 import Qwen2Config + + +class DotsVisionConfig(PretrainedConfig): + model_type: str = "dots_vit" + + def __init__( + self, + embed_dim: int = 1536, # vision encoder embed size + hidden_size: int = 1536, # after merger hidden size + intermediate_size: int = 4224, + num_hidden_layers: int = 42, + num_attention_heads: int = 12, + num_channels: int = 3, + patch_size: int = 14, + spatial_merge_size: int = 2, + temporal_patch_size: int = 1, + rms_norm_eps: float = 1e-5, + use_bias: bool = False, + attn_implementation="flash_attention_2", + initializer_range=0.02, + init_merger_std=0.02, + is_causal=False, # ve causal forward + post_norm=True, + gradient_checkpointing=False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.rms_norm_eps = rms_norm_eps + self.use_bias = use_bias + self.attn_implementation = attn_implementation + self.initializer_range = initializer_range + self.init_merger_std = init_merger_std + self.is_causal = is_causal + self.post_norm = post_norm + self.gradient_checkpointing = gradient_checkpointing + + +class DotsOCRConfig(Qwen2Config): + model_type = "dots_ocr" + + def __init__(self, + image_token_id=151665, + video_token_id=151656, + vision_config: Optional[dict] = None, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.image_token_id = image_token_id + self.video_token_id = video_token_id + self.vision_config = DotsVisionConfig(**(vision_config or {})) + + def save_pretrained(self, save_directory, **kwargs): + self._auto_class = None + super().save_pretrained(save_directory, **kwargs) From b608cb4d9a38d3401afce4953543249eec749a38 Mon Sep 17 00:00:00 2001 From: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Date: Mon, 22 Sep 2025 10:49:13 +0800 Subject: [PATCH 14/17] [Docs] GSM8K Accuracy Evaluation doc update (#25360) Signed-off-by: David Chen <530634352@qq.com> --- tests/evals/gsm8k/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/evals/gsm8k/README.md b/tests/evals/gsm8k/README.md index 58572c3a6fbc..29c5199e1e87 100644 --- a/tests/evals/gsm8k/README.md +++ b/tests/evals/gsm8k/README.md @@ -19,7 +19,7 @@ pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \ vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000 # Run evaluation -python tests/gsm8k/gsm8k_eval.py --port 8000 +python tests/evals/gsm8k/gsm8k_eval.py --port 8000 ``` ## Configuration Format From b012cf68f51c646090b5188615daff94903fc3f9 Mon Sep 17 
00:00:00 2001 From: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Date: Mon, 22 Sep 2025 11:35:39 +0800 Subject: [PATCH 15/17] [Bugfix] Fix hermes tool parser handling of non-string argument types (#22002) Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: David Chen <530634352@qq.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Chauncey --- .../tool_parsers/test_hermes_tool_parser.py | 131 ++++++++++++++++++ .../openai/tool_parsers/hermes_tool_parser.py | 42 +++++- 2 files changed, 166 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 4bab849f47c2..e0e6b2c07e17 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -45,8 +45,39 @@ }, }] +PRODUCT_TOOLS = [{ + "type": "function", + "function": { + "name": "get_product_info", + "description": "Get detailed information of a product based on its " + "product ID.", + "parameters": { + "type": "object", + "properties": { + "inserted": { + "type": "boolean", + "description": "inserted.", + }, + "product_id": { + "type": "integer", + "description": "The product ID of the product.", + }, + }, + "required": ["product_id", "inserted"], + }, + }, +}] + MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}] +PRODUCT_MESSAGES = [{ + "role": + "user", + "content": + "Hi! Do you have any detailed information about the product id " + "7355608 and inserted true?" +}] + @pytest.mark.asyncio async def test_non_streaming_tool_call(): @@ -127,3 +158,103 @@ async def test_streaming_tool_call(): print("\n[Streaming Test Passed]") print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") print(f"Reconstructed Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_non_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in non-streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + response = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + ) + + assert response.choices + choice = response.choices[0] + message = choice.message + + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None + + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_product_info" + + arguments = json.loads(tool_call.function.arguments) + assert "product_id" in arguments + assert "inserted" in arguments + + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Non-Streaming Product Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") + + +@pytest.mark.asyncio +async def test_streaming_product_tool_call(): + """Test tool call integer and boolean parameters in streaming mode.""" + with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: + client = server.get_async_client() + + stream = await client.chat.completions.create( + model=LORA_MODEL, + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + stream=True, + ) + + tool_call_chunks = {} 
+ async for chunk in stream: + if not chunk.choices: + continue + + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue + + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} + + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index][ + "arguments"] += tool_chunk.function.arguments + + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] + + assert reconstructed_tool_call["name"] == "get_product_info" + + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "product_id" in arguments + assert "inserted" in arguments + + # Handle type coercion for streaming test as well + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") + + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True + + print("\n[Streaming Product Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index e74c420da1d3..87595953da06 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -368,16 +368,32 @@ def extract_tool_calls_streaming( # case -- we now have the first info about arguments available from # autocompleting the JSON elif cur_arguments and not prev_arguments: + # extract the content after {"name": ..., "arguments": + # directly from tool_call_portion as cur_arguments_json, + # since cur_arguments may differ from the original text + # due to partial JSON parsing + # for example, tool_call_portion = + # {"name": "search", "arguments": {"search_request": {" + # but cur_arguments = + # {"search_request": {}} + function_name = current_tool_call.get("name") + match = re.search( + r'\{"name":\s*"' + + re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)', + tool_call_portion.strip(), re.DOTALL) + if match: + cur_arguments_json = match.group(1) + else: + cur_arguments_json = json.dumps(cur_arguments, + ensure_ascii=False) - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False) logger.debug("finding %s in %s", delta_text, cur_arguments_json) - # get the location where previous args differ from current - if (delta_text not in cur_arguments_json[:-2]): + # get the location where previous args differ from current. + if (delta_text not in cur_arguments_json): return None - args_delta_start_loc = cur_arguments_json[:-2]. \ + args_delta_start_loc = cur_arguments_json. \ rindex(delta_text) + \ len(delta_text) @@ -397,8 +413,20 @@ def extract_tool_calls_streaming( # last case -- we have an update to existing arguments. 
             elif cur_arguments and prev_arguments:
-                if isinstance(delta_text, str) and len(delta_text.rstrip(
-                )) >= 1 and delta_text.rstrip()[-1] == '}':
+                # judge whether the tool_call_portion is a complete JSON
+                try:
+                    json.loads(tool_call_portion)
+                    is_complete_json = True
+                except Exception:
+                    is_complete_json = False
+
+                # if the delta_text ends with a '}' and tool_call_portion is a
+                # complete JSON, then the last '}' does not belong to the
+                # arguments, so we should trim it off
+                if isinstance(delta_text, str) \
+                    and len(delta_text.rstrip()) >= 1 \
+                    and delta_text.rstrip()[-1] == '}' \
+                    and is_complete_json:
                     delta_text = delta_text.rstrip()[:-1]
 
                 logger.debug("got diff %s", delta_text)
 

From 13af5668036d9d251aa5b9b83115193f2ff5d132 Mon Sep 17 00:00:00 2001
From: Juechen Liu
Date: Sun, 21 Sep 2025 22:54:48 -0700
Subject: [PATCH 16/17] add unit test

Signed-off-by: Juechen Liu
---
 tests/v1/metrics/test_engine_logger_apis.py | 51 ++++++++++++++++++++-
 1 file changed, 50 insertions(+), 1 deletion(-)

diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py
index e6a4d0a2a2e8..b350556fc73f 100644
--- a/tests/v1/metrics/test_engine_logger_apis.py
+++ b/tests/v1/metrics/test_engine_logger_apis.py
@@ -6,6 +6,8 @@
 
 from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
 from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
+from vllm.v1.metrics.loggers import LoggingStatLogger
+from vllm.v1.metrics.stats import IterationStats
 
 
 class DummyStatLogger:
@@ -31,6 +33,15 @@ def log_engine_initialized(self):
         self.engine_initialized = True
 
 
+class DummyLoggingStatLogger(LoggingStatLogger):
+    """
+    A dummy logging stat logger for testing purposes.
+    It only adds an accessor for the accumulated preemption counter.
+    """
+    def get_num_preempted_reqs(self) -> int:
+        return self.num_preempted_reqs
+
+
 @pytest.fixture
 def log_stats_enabled_engine_args():
     """
@@ -62,7 +73,7 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args):
     """
-    It's still possible to use custom stat loggers exclusively by passing 
+    It's still possible to use custom stat loggers exclusively by passing
     disable_log_stats=True in addition to a list of custom stat loggers.
     """
     # Create engine_args with disable_log_stats=True for this test
@@ -81,3 +92,41 @@ async def test_async_llm_add_to_default_loggers(log_stats_enabled_engine_args):
     assert engine.log_stats
 
     engine.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_logger_iteration_stats(log_stats_enabled_engine_args):
+    """Verify that LoggingStatLogger accumulates iteration stats
+    (preemptions, prompt and generation tokens) across record() calls."""
+    # Create engine_args with disable_log_stats=True for this test
+    disabled_log_engine_args = copy.deepcopy(log_stats_enabled_engine_args)
+    disabled_log_engine_args.disable_log_stats = True
+
+    # Disable default loggers; pass custom stat logger to the constructor
+    engine = AsyncLLM.from_engine_args(disabled_log_engine_args,
+                                       stat_loggers=[DummyLoggingStatLogger])
+
+    dummy_logger = engine.logger_manager.per_engine_logger_dict[0][0]
+
+    assert len(engine.logger_manager.per_engine_logger_dict[0]) == 1
+    assert isinstance(dummy_logger, DummyLoggingStatLogger)
+
+    stats_1 = IterationStats()
+    stats_1.num_preempted_reqs = 1
+    stats_1.num_generation_tokens = 10
+    stats_1.num_prompt_tokens = 100
+
+    stats_2 = IterationStats()
+    stats_2.num_preempted_reqs = 2
+    stats_2.num_generation_tokens = 20
+    stats_2.num_prompt_tokens = 200
+
+    # Each record() call should accumulate into the logger's local totals
+    dummy_logger.record(scheduler_stats=None, iteration_stats=stats_1)
+    dummy_logger.record(scheduler_stats=None, iteration_stats=stats_2)
+
+    assert dummy_logger.num_preempted_reqs == 3
+    assert dummy_logger.num_generation_tokens == 30
+    assert dummy_logger.num_prompt_tokens == 300
+
+    engine.shutdown()

From 170ba7d88f228a3a22f15d332c17fd8af884bf89 Mon Sep 17 00:00:00 2001
From: Juechen Liu
Date: Sun, 21 Sep 2025 23:32:40 -0700
Subject: [PATCH 17/17] format code

Signed-off-by: Juechen Liu
---
 tests/v1/metrics/test_engine_logger_apis.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/v1/metrics/test_engine_logger_apis.py b/tests/v1/metrics/test_engine_logger_apis.py
index b350556fc73f..603790f87db5 100644
--- a/tests/v1/metrics/test_engine_logger_apis.py
+++ b/tests/v1/metrics/test_engine_logger_apis.py
@@ -5,8 +5,8 @@
 import pytest
 
 from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
-from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
 from vllm.v1.metrics.loggers import LoggingStatLogger
+from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger
 from vllm.v1.metrics.stats import IterationStats
 
 
@@ -38,6 +38,7 @@ class DummyLoggingStatLogger(LoggingStatLogger):
     A dummy logging stat logger for testing purposes.
     It only adds an accessor for the accumulated preemption counter.
     """
+
     def get_num_preempted_reqs(self) -> int:
         return self.num_preempted_reqs