
Conversation

Contributor

@rahul-tuli rahul-tuli commented Oct 3, 2025

Summary

Support multiple layers in Eagle3 checkpoints. The input dimension of the first layer should be 2 * hidden_size, while that of the second (and subsequent) layers should be hidden_size, because no input embeddings are present to concatenate for layers after the first.
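
For illustration, the sizing rule reduces to the following minimal sketch (the helper name is hypothetical, not vLLM's actual code):

# Hedged sketch of the per-layer QKV sizing rule described above.
def qkv_input_size(hidden_size: int, layer_idx: int) -> int:
    # Layer 0 consumes the concatenation of input embeddings and hidden states,
    # so its QKV projection reads 2 * hidden_size features; later layers only
    # see the previous layer's hidden states.
    return 2 * hidden_size if layer_idx == 0 else hidden_size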

Changes

  • Updated LlamaDecoderLayer to accept a layer_idx parameter and dynamically set QKV input size based on layer position
  • Modified LlamaModel to support multiple Eagle3 layers via num_hidden_layers config parameter
  • Updated the forward pass to provide input embeddings only to the first layer (see the sketch after this list)
  • Added smoke test for multi-layer Eagle3 models in test_eagle3.py
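
A minimal sketch of how these changes fit together, assuming hypothetical class and argument names (the real LlamaModel/LlamaDecoderLayer signatures in vLLM differ):

import torch

class MultiLayerDrafterSketch(torch.nn.Module):
    """Illustrative only: mirrors the bullet points above, not vLLM's code."""

    def __init__(self, config, decoder_layer_cls):
        super().__init__()
        # One drafter layer per config.num_hidden_layers; each layer is told
        # its position so it can size its QKV input accordingly.
        self.layers = torch.nn.ModuleList(
            decoder_layer_cls(config, layer_idx=i)
            for i in range(config.num_hidden_layers)
        )

    def forward(self, positions, input_embeds, hidden_states):
        residual = None
        for idx, layer in enumerate(self.layers):
            # Only the first layer receives input embeddings to concatenate;
            # later layers run purely on the previous layer's hidden states,
            # with the residual carried across layer boundaries.
            embeds = input_embeds if idx == 0 else None
            hidden_states, residual = layer(positions, embeds, hidden_states, residual)
        return hidden_states, residual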

Testing

Smoke Test Added

Added a parametrized test case, llama3-eagl3-multiple-layers, using nm-testing/random-weights-llama3.1.8b-2layer-eagle3 to verify multi-layer Eagle3 initialization and execution.
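
Roughly, the parametrization looks like the following hedged sketch (the test function body here is a placeholder, not the actual logic in test_eagle3.py):

import pytest

@pytest.mark.parametrize(
    "model_path",
    [
        pytest.param(
            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
            id="llama3-eagl3-multiple-layers",
        ),
    ],
)
def test_eagle3_multi_layer_smoke(model_path):
    # Placeholder assertion; the real smoke test builds the drafter from this
    # checkpoint and runs a short generation to confirm it loads and executes.
    assert "2layer-eagle3" in model_path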

Manual Verification Script
# vllm_run.py
from vllm import LLM, SamplingParams

# Configuration
# multiple-layers-eagle3-drafter
verifier = "meta-llama/Llama-3.1-8B"
drafter = "nm-testing/random-weights-llama3.1.8b-2layer-eagle3"

tensor_parallel_size = 1
gpu_memory_utilization = 0.8

speculative_config = {
    "model": drafter,
    "method": "eagle3",
    "num_speculative_tokens": 3,
}

prompts = [
    "Hello, my name is Alice and I work as a software engineer.",
    "The president of the United States is responsible for leading the executive branch.",
    "The capital of France is Paris, a beautiful city known for its art and culture.",
    "The future of AI is incredibly promising and will transform many aspects of life.",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=100)

# Initialize LLM with Eagle3
print("Initializing Eagle3 speculative decoding...")
llm = LLM(
    model=verifier,
    tensor_parallel_size=tensor_parallel_size,
    gpu_memory_utilization=gpu_memory_utilization,
    speculative_config=speculative_config,
    disable_log_stats=False,
)

# Generate text
print("Generating text...")
outputs = llm.generate(prompts, sampling_params)

# Print results
print("\n=== Results ===")
for i, output in enumerate(outputs):
    print(f"\nPrompt {i+1}: {output.prompt[:60]}...")
    print(f"Generated: {output.outputs[0].text[:80]}...")

Run command:

VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=6 python vllm_run.py

Verification Output
INFO 09-26 22:05:43 [llm.py:302] Supported_tasks: ['generate']
Generating text...
Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 414.63it/s]
Processed prompts: 100%|███████████████████████████| 4/4 [00:01<00:00,  3.72it/s, est. speed input: 59.49 toks/s, output: 371.82 toks/s]

=== Results ===

Prompt 1: Hello, my name is Alice and I work as a software engineer....
Generated:  I started this blog to share knowledge and experiences of my projects, research...

Prompt 2: The president of the United States is responsible for leadin...
Generated:  The president is responsible for overseeing the implementation of the laws that...

Prompt 3: The capital of France is Paris, a beautiful city known for i...
Generated:  Its most famous landmark is the Eiffel Tower, which was built in 1889 and still...

Prompt 4: The future of AI is incredibly promising and will transform ...
Generated:  AI will play an important role in the future of medicine, enabling doctors to p...

Checkpoint for verification: https://huggingface.co/nm-testing/random-weights-llama3.1.8b-2layer-eagle3

@mergify mergify bot added the llama and speculative-decoding labels on Oct 3, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for multiple hidden layers in Eagle3 speculative decoding. The changes correctly modify the model initialization to create multiple decoder layers and update the forward pass to iterate through them. A new test case is also added to verify the functionality.

My main concern is a critical issue in LlamaDecoderLayer.forward where the residual connection is broken between layers for layer_idx > 0. The residual from the previous layer is ignored, which will likely lead to incorrect model outputs. I've provided a detailed comment and a partial code suggestion to address this.

@rahul-tuli rahul-tuli force-pushed the support-multiple-eagle3-layers branch 5 times, most recently from 451022d to 53614da, on October 8, 2025 at 12:21
@rahul-tuli rahul-tuli force-pushed the support-multiple-eagle3-layers branch from 53614da to b1bb2c4 on October 8, 2025 at 12:49
Signed-off-by: Rahul Tuli <[email protected]>
@rahul-tuli
Contributor Author

/gemini re-review please!

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for multiple hidden layers in Eagle3, which is a significant enhancement. The changes are logical and well-structured. I've identified one high-severity issue related to memory efficiency that should be addressed.

Collaborator

@NickLucche NickLucche left a comment

LGTM, although I am not super familiar with this speculator family atm.
Would you say this is expected behavior (num_layers>2 for eagle speculators) or is this a behavior we should be mentioning in the docs?

Comment on lines +25 to +28
pytest.param(
    "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
    id="llama3-eagl3-multiple-layers",
),
Collaborator

nit: this is a mock model right? I think we can comment that

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Oct 8, 2025
@NickLucche
Collaborator

@codex review

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "midlayer." in name:
            name = name.replace("midlayer.", "layers.0.")
        for param_name, weight_name, shard_id in stacked_params_mapping:

P1: Preserve per-layer names when loading multi-layer Eagle3 weights

The commit now instantiates num_hidden_layers drafter layers, but load_weights still rewrites every checkpoint tensor whose name contains "midlayer." to "layers.0.". Multi-layer Eagle3 checkpoints necessarily distinguish layers (e.g. midlayer.layers.1.*), and this unconditional replacement collapses all such tensors onto the first module. As a result, only layer 0 receives weights, while later layers either fail to load (KeyError in params_dict[name]) or remain randomly initialized. The mapping should retain the original layer index (e.g. midlayer.layers.{i} → layers.{i}) instead of hardcoding layers.0.
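
To make the suggestion concrete, here is a hedged sketch of an index-preserving rename; it assumes, per the review, that multi-layer checkpoints use names like midlayer.layers.1.*, and is not the code from this PR:

import re

def remap_eagle3_name(name: str) -> str:
    # Preserve the per-layer index for multi-layer checkpoints, e.g.
    # "midlayer.layers.1.self_attn.q_proj.weight" -> "layers.1.self_attn.q_proj.weight".
    match = re.match(r"midlayer\.layers\.(\d+)\.(.+)", name)
    if match:
        return f"layers.{match.group(1)}.{match.group(2)}"
    # Single-layer checkpoints keep the original behavior and map onto layer 0.
    return name.replace("midlayer.", "layers.0.")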


Collaborator

@benchislett benchislett left a comment

LGTM, thanks!

@rahul-tuli
Contributor Author

LGTM, although I am not super familiar with this speculator family atm. Would you say this is expected behavior (num_layers>2 for eagle speculators) or is this a behavior we should be mentioning in the docs?

As of now, most Eagle3 checkpoints have just a single layer by default; however, there has been interest in trying multiple layers in the Eagle3 drafter to see whether that improves acceptance rates. Our team is already training one such speculator!

Member

@yewentao256 yewentao256 left a comment

Please also fix the pre-commit issue so that we can land

@mgoin mgoin enabled auto-merge (squash) October 8, 2025 15:37
@github-actions github-actions bot added the ready label on Oct 8, 2025
@mgoin mgoin merged commit cf4cd6c into vllm-project:main Oct 9, 2025
54 checks passed
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025

Labels

ci/build, deepseek, documentation, frontend, gpt-oss, kv-connector, llama, multi-modality, qwen, ready, speculative-decoding, v1

Projects

Status: Done

6 participants