Merged
43 commits
8874e16
[V1] support eagle and eagle3 for qwen2_5vl
Aug 14, 2025
f8af6b8
Merge branch 'vllm-project:main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Aug 15, 2025
b242260
fix bug
Aug 21, 2025
6b682de
support M-RoPE in eagle
Aug 21, 2025
475ce2b
fix bug for SupportsEagle3 and graph compile
Aug 21, 2025
7df533c
fix bug
Aug 21, 2025
fdb7cb2
Merge remote-tracking branch 'origin' into Eagle_support_mrope
Aug 25, 2025
ce18a91
fix bug
Aug 25, 2025
35b8625
Merge branch 'vllm-project:main' into Eagle_support_mrope
LJH-LBJ Aug 25, 2025
bfde2ca
optimize code
Sep 2, 2025
babd305
Merge remote-tracking branch 'upstream/main' into Eagle_support_mrope
Sep 3, 2025
f2c5c19
Merge branch 'vllm-project:main' into Eagle_support_mrope
LJH-LBJ Sep 3, 2025
b0f2181
llama_eagle3 support mm
Sep 3, 2025
9499db4
[llama_eagle3] delete support_torch_compile
Sep 22, 2025
15f24d2
[llama_eagle3] delete get_input_embeddings
Sep 23, 2025
ef44506
[llama_eagle3] fix get_input_embeddings
Sep 23, 2025
55f66b7
[fix bug] delete duplicated code
Sep 23, 2025
57a1b17
Merge branch 'main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Sep 24, 2025
026dde7
fix pre-commit
Sep 24, 2025
fa20a26
fix pre-commit
Sep 24, 2025
50b13f5
fix bug
Sep 25, 2025
ef940da
add benchmark_run_mmstar
Sep 25, 2025
6083933
Merge branch 'Eagle-mulitmodal-support-Qwen2.5vl' of https://github.c…
Sep 25, 2025
ae08501
fix benchmark
Sep 25, 2025
1447be5
fix benchmark
Sep 25, 2025
6e15b6e
fix benchmark
Sep 25, 2025
8cd5a98
fix benchmark
Sep 25, 2025
54fe6b3
Merge branch 'main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Sep 25, 2025
d2bc9e5
fix pre-commit
Sep 25, 2025
50f31df
fix test
Sep 25, 2025
b629dbb
fix bug
Sep 25, 2025
727be44
fix bug
Sep 25, 2025
f8b0651
fix pre-commit
Sep 25, 2025
2133bf3
Merge branch 'main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Sep 26, 2025
0a83ffb
delete benchmarks\benchmarks_run_mmstar.py
Sep 26, 2025
63b726a
fix bug
Sep 26, 2025
23bf946
Merge branch 'main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Sep 26, 2025
6d2882f
opt
Sep 26, 2025
cbe0744
opt
Sep 26, 2025
e5e29e0
opt
Sep 26, 2025
4e24437
fix pre-commit
Sep 26, 2025
c38907b
fix bug
Sep 27, 2025
6096054
Merge branch 'main' into Eagle-mulitmodal-support-Qwen2.5vl
LJH-LBJ Sep 27, 2025
3 changes: 3 additions & 0 deletions tests/models/registry.py
@@ -651,6 +651,9 @@ def check_available_online(
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
"Qwen/Qwen2.5-VL-7B-Instruct",
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.3"),
}
8 changes: 5 additions & 3 deletions tests/v1/e2e/test_spec_decode.py
@@ -127,6 +127,8 @@ def test_ngram_correctness(

@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -145,9 +147,9 @@ def test_ngram_correctness(
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm",
"deepseek_eagle"
"qwen3_eagle3", "qwen2_5_vl_eagle3",
"llama3_eagle", "llama3_eagle3", "llama4_eagle",
"llama4_eagle_mm", "deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
80 changes: 80 additions & 0 deletions vllm/benchmarks/datasets.py
@@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
):
dataset_class = MLPerfDataset
args.hf_split = "train"
elif (
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
):
dataset_class = MMStarDataset
args.hf_split = "val"
args.hf_subset = None
else:
supported_datasets = set([
dataset_name for cls in HuggingFaceDataset.__subclasses__()
@@ -2721,3 +2728,76 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:

random.shuffle(requests)
return requests


# -----------------------------------------------------------------------------
# MMStar Dataset Implementation
# -----------------------------------------------------------------------------


class MMStarDataset(HuggingFaceDataset):
"""
Lin-Chen/MMStar: https://huggingface.co/datasets/Lin-Chen/MMStar
refer to: https://github.com/sgl-project/SpecForge/pull/106
"""
DEFAULT_OUTPUT_LEN = 128
SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
IS_MULTIMODAL = True

def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
request_id_prefix: str = "",
no_oversample: bool = False,
**kwargs,
) -> list[SampleRequest]:
# If --hf-output-len is not set, use the default output length.
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests: list[SampleRequest] = []

for ind, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
# Split the question text from options
# (keep only the part before "Options:").
full_q: str = item.get("question", "")
question_text = full_q.split("Options:", 1)[0].strip()

# Multimodal image content.
mm_content = process_image(item["image"])

# Compute prompt token length (note: this is plain text length
# if enable_multimodal_chat is False).
prompt_len = len(tokenizer(question_text).input_ids)

if enable_multimodal_chat:
# If multimodal content should be embedded in the chat message,
# convert to [{"role":"user","content":[...]}]
prompt = self.apply_multimodal_chat_transformation(
question_text, mm_content
)
mm_for_request = None # Already embedded in chat content.
else:
# Default: prompt is plain text,
# image is in mm_content for the bench to assemble.
prompt = question_text
mm_for_request = mm_content

sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_for_request,
request_id=request_id_prefix + str(ind),
)
)

self.maybe_oversample_requests(
sampled_requests, num_requests, request_id_prefix, no_oversample
)
return sampled_requests
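
For reference, a minimal standalone sketch (not part of this diff) of the two request shapes sample() can produce, depending on enable_multimodal_chat. The placeholder values and the chat-message layout are assumptions modeled on the OpenAI-style format used by apply_multimodal_chat_transformation, not copied from this PR:

# Sketch only: illustrates the prompt handling above with placeholder data.
item = {
    "question": "What is shown in the image? Options: (A) a cat (B) a dog",
    "image": "<PIL.Image.Image loaded from the dataset>",
}
# Keep only the text before "Options:", as sample() does.
question_text = item["question"].split("Options:", 1)[0].strip()

# enable_multimodal_chat=False (default): plain-text prompt; the image stays
# in multi_modal_data for the benchmark harness to attach.
plain_request = {"prompt": question_text, "multi_modal_data": item["image"]}

# enable_multimodal_chat=True: the image is embedded in a chat-style message
# and multi_modal_data is left as None (assumed layout).
chat_request = {
    "prompt": [{
        "role": "user",
        "content": [
            {"type": "text", "text": question_text},
            {"type": "image_url", "image_url": {"url": "<base64 data URI>"}},
        ],
    }],
    "multi_modal_data": None,
}
print(question_text)  # -> "What is shown in the image?"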
27 changes: 19 additions & 8 deletions vllm/model_executor/models/llama_eagle3.py
@@ -8,7 +8,6 @@
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -19,6 +18,7 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)

@@ -102,7 +102,6 @@ def forward(
return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):

def __init__(
@@ -145,13 +144,21 @@ def __init__(
eps=self.config.rms_norm_eps,
)

def get_input_embeddings(
self,
input_ids: torch.Tensor,
) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if input_embeds is None:
Review comment: It seems like input_embeds is missing an 's' in the code.

Reply (Contributor): Sorry, I cannot understand. In L163, there is an 's' in input_embeds. @ggg-s

Reply (quoting the above): The error happened because the dynamic_arg_dims dictionary passed to @support_torch_compile used the key inputs_embeds while the forward function defined the argument as input_embeds, so TorchDynamo could not find a matching parameter and raised an error like ValueError: dynamic_arg_dims specifies 'inputs_embeds' but forward() has no such argument. Fixing the naming mismatch removes that specific error, but vLLM still fails to start because the dynamic dimension annotations themselves cause further shape-guard conflicts; in practice, the service only starts normally once the entire @support_torch_compile(dynamic_arg_dims=...) decorator is removed.

Reply (Contributor): Thanks for explaining. I'd better remove support_torch_compile.

Reply (Collaborator): @LJH-LBJ removal of torch compile is having a performance impact downstream, in particular in #25109. Do we know of a workaround for this issue?
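
A minimal, self-contained sketch of the naming mismatch described above. The decorator here is a stand-in that only mimics the signature check implied by the error message; it is not vLLM's @support_torch_compile:

import inspect

def check_dynamic_arg_dims(dynamic_arg_dims):
    # Stand-in: require every dynamic_arg_dims key to name a forward() argument.
    def wrap(cls):
        params = inspect.signature(cls.forward).parameters
        for name in dynamic_arg_dims:
            if name not in params:
                raise ValueError(
                    f"dynamic_arg_dims specifies {name!r} "
                    "but forward() has no such argument")
        return cls
    return wrap

try:
    # Dict key says "inputs_embeds", but the parameter is "input_embeds".
    @check_dynamic_arg_dims({"input_ids": 0, "inputs_embeds": 0})
    class TinyDraftModel:
        def forward(self, input_ids, positions, hidden_states,
                    input_embeds=None):
            ...
except ValueError as e:
    print(e)  # dynamic_arg_dims specifies 'inputs_embeds' but forward() has no such argument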

input_embeds = self.get_input_embeddings(input_ids)
assert hidden_states.shape[-1] == input_embeds.shape[-1]

residual = None
@@ -239,11 +246,7 @@ def forward(
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
return self.model(input_ids, positions, hidden_states, inputs_embeds)

def compute_logits(
self,
@@ -299,3 +302,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())

def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
return inputs_embeds
11 changes: 9 additions & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -68,7 +68,7 @@
from vllm.utils import is_pin_memory_available
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .interfaces import (MultiModalEmbeddings, SupportsEagle3, SupportsLoRA,
SupportsMultiModal, SupportsMultiModalPruning,
SupportsPP, SupportsQuant)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
@@ -965,7 +965,7 @@ def get_replacement_qwen2vl(item_idx: int, modality: str):
dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP,
SupportsQuant,
SupportsQuant, SupportsEagle3,
SupportsMultiModalPruning):

packed_modules_mapping = {
@@ -1028,6 +1028,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)

def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
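
As a quick sanity check of the default chosen above: assuming the Qwen2.5-VL-7B language model has 28 decoder layers (a property of the checkpoint, not stated in this diff), the heuristic picks an early, a middle, and a late layer:

# Sketch only: evaluates the (2, num_layers // 2, num_layers - 3) default.
num_layers = 28  # assumed layer count for Qwen/Qwen2.5-VL-7B-Instruct
aux_layers = (2, num_layers // 2, num_layers - 3)
print(aux_layers)  # -> (2, 14, 25)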
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -286,6 +286,7 @@
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
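
A hedged usage sketch of how the new registry entry would typically be exercised offline. The speculative_config keys mirror vLLM's existing EAGLE examples and the drafter checkpoint matches the test above; the num_speculative_tokens value is illustrative, so treat the whole snippet as an assumption rather than documented usage:

# Sketch only: EAGLE3 speculative decoding with the Qwen2.5-VL target model.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "Rayzl/qwen2.5-vl-7b-eagle3-sgl",
        "num_speculative_tokens": 2,
    },
)
outputs = llm.generate(
    ["Describe speculative decoding in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)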