From ea60cdaf7bc6c89d70d4607d0f3d78bb3f637017 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 3 Oct 2025 12:05:44 +0000 Subject: [PATCH 1/3] Add: Support for multiple layers in eagle3 Signed-off-by: Rahul Tuli --- .../speculators/test_eagle3.py | 35 +++++------- vllm/model_executor/models/llama_eagle3.py | 55 ++++++++++++------- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index 5ce6e1593b5c..db9679d6aa6e 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -7,26 +7,21 @@ from vllm.model_executor.models.interfaces import supports_eagle3 -@pytest.mark.parametrize( - "model_path", - [ - pytest.param( - "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", - id="llama3-eagle3-speculator", - ), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", - id="qwen3-eagle3-speculator", - ), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", - id="qwen3-eagle3-speculator-w4a16-verifier", - ), - ], -) -def test_eagle3_speculators_model( - vllm_runner, example_prompts, model_path, monkeypatch -): +@pytest.mark.parametrize("model_path", [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator"), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator"), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", + id="qwen3-eagle3-speculator-w4a16-verifier"), + pytest.param("nm-testing/random-weights-llama3.1.8b-2layer-eagle3", + id="llama3-eagl3-multiple-layers"), +]) +def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, + monkeypatch): """ Test Eagle3 speculators models properly initialize speculative decoding. 
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 155a4ecea28f..e15939ed0a48 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -29,20 +29,27 @@ class LlamaDecoderLayer(LlamaDecoderLayer): + def __init__( self, vllm_config: VllmConfig, prefix: str = "", config: Optional[LlamaConfig] = None, + layer_idx: int = 0, ) -> None: super().__init__(vllm_config, prefix=prefix, config=config) config = config or vllm_config.model_config.hf_config quant_config = self.get_quant_config(vllm_config) + # First layer uses 2*hidden_size (embeds + hidden_states concatenated) + # Subsequent layers use hidden_size (only hidden_states, no embeds) + qkv_input_size = (2 * self.hidden_size + if layer_idx == 0 else self.hidden_size) + # override qkv self.self_attn.qkv_proj = QKVParallelLinear( - 2 * self.hidden_size, + qkv_input_size, self.self_attn.head_dim, self.self_attn.total_num_heads, self.self_attn.total_num_kv_heads, @@ -52,6 +59,7 @@ def __init__( ) self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layer_idx = layer_idx if getattr(config, "norm_before_residual", False): self._residual_norm = self._norm_before_residual @@ -90,11 +98,18 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - embeds = self.input_layernorm(embeds) - hidden_states, residual = self._residual_norm(hidden_states=hidden_states) + if self.layer_idx == 0: + embeds = self.input_layernorm(embeds) + hidden_states, residual = self._residual_norm( + hidden_states=hidden_states) + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + else: + # Subsequent layers: only process hidden_states + # (no embeds, no input_layernorm) + hidden_states, residual = self._residual_norm( + hidden_states=hidden_states) - hidden_states = torch.cat([embeds, hidden_states], dim=-1) # Self Attention hidden_states = self.self_attn( positions=positions, @@ -129,15 +144,15 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList( - [ - LlamaDecoderLayer( - current_vllm_config, - prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), - config=self.config, - ) - ] - ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(current_vllm_config, + prefix=maybe_prefix( + prefix, + f"layers.{layer_idx + start_layer_id}"), + config=self.config, + layer_idx=layer_idx) + for layer_idx in range(self.config.num_hidden_layers) + ]) if hasattr(self.config, "target_hidden_size"): self.fc = torch.nn.Linear( self.config.target_hidden_size * 3, self.config.hidden_size, bias=False @@ -166,13 +181,13 @@ def forward( assert hidden_states.shape[-1] == input_embeds.shape[-1] residual = None - hidden_states, residual = self.layers[0]( - positions, - input_embeds, - hidden_states, - residual, - ) - + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + embeds=input_embeds, + hidden_states=hidden_states, + residual=residual, + ) hidden_states, hidden_prenorm = self.norm(hidden_states, residual) return hidden_states, hidden_prenorm From b1bb2c498f99d5d78f7458efbac59d541f51f378 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 8 Oct 2025 12:42:02 +0000 Subject: [PATCH 2/3] Use: self.input_layernorm for subsequent layers Signed-off-by: Rahul Tuli --- .../speculators/test_eagle3.py | 39 ++++++++++++------- vllm/model_executor/models/llama_eagle3.py | 33 +++++++--------- 2 files changed, 39 insertions(+), 33 
deletions(-) diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py index db9679d6aa6e..19ba32d8dee4 100644 --- a/tests/speculative_decoding/speculators/test_eagle3.py +++ b/tests/speculative_decoding/speculators/test_eagle3.py @@ -7,21 +7,30 @@ from vllm.model_executor.models.interfaces import supports_eagle3 -@pytest.mark.parametrize("model_path", [ - pytest.param( - "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", - id="llama3-eagle3-speculator"), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", - id="qwen3-eagle3-speculator"), - pytest.param( - "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", - id="qwen3-eagle3-speculator-w4a16-verifier"), - pytest.param("nm-testing/random-weights-llama3.1.8b-2layer-eagle3", - id="llama3-eagl3-multiple-layers"), -]) -def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path, - monkeypatch): +@pytest.mark.parametrize( + "model_path", + [ + pytest.param( + "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized", + id="llama3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized", + id="qwen3-eagle3-speculator", + ), + pytest.param( + "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", + id="qwen3-eagle3-speculator-w4a16-verifier", + ), + pytest.param( + "nm-testing/random-weights-llama3.1.8b-2layer-eagle3", + id="llama3-eagl3-multiple-layers", + ), + ], +) +def test_eagle3_speculators_model( + vllm_runner, example_prompts, model_path, monkeypatch +): """ Test Eagle3 speculators models properly initialize speculative decoding. diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index e15939ed0a48..15f8153250af 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -29,7 +29,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer): - def __init__( self, vllm_config: VllmConfig, @@ -44,8 +43,7 @@ def __init__( # First layer uses 2*hidden_size (embeds + hidden_states concatenated) # Subsequent layers use hidden_size (only hidden_states, no embeds) - qkv_input_size = (2 * self.hidden_size - if layer_idx == 0 else self.hidden_size) + qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size # override qkv self.self_attn.qkv_proj = QKVParallelLinear( @@ -98,17 +96,14 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: - if self.layer_idx == 0: embeds = self.input_layernorm(embeds) - hidden_states, residual = self._residual_norm( - hidden_states=hidden_states) + hidden_states, residual = self._residual_norm(hidden_states=hidden_states) hidden_states = torch.cat([embeds, hidden_states], dim=-1) else: # Subsequent layers: only process hidden_states - # (no embeds, no input_layernorm) - hidden_states, residual = self._residual_norm( - hidden_states=hidden_states) + # and residuals + hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention hidden_states = self.self_attn( @@ -144,15 +139,17 @@ def __init__( prefix=maybe_prefix(prefix, "embed_tokens"), ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(current_vllm_config, - prefix=maybe_prefix( - prefix, - f"layers.{layer_idx + start_layer_id}"), - config=self.config, - layer_idx=layer_idx) - for layer_idx in range(self.config.num_hidden_layers) - ]) + self.layers = 
nn.ModuleList( + [ + LlamaDecoderLayer( + current_vllm_config, + prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"), + config=self.config, + layer_idx=layer_idx, + ) + for layer_idx in range(self.config.num_hidden_layers) + ] + ) if hasattr(self.config, "target_hidden_size"): self.fc = torch.nn.Linear( self.config.target_hidden_size * 3, self.config.hidden_size, bias=False From a039040b3bd304fd5e4bb68e05d4e9f553b90b0d Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 8 Oct 2025 12:54:01 +0000 Subject: [PATCH 3/3] Add: better comment Signed-off-by: Rahul Tuli --- vllm/model_executor/models/llama_eagle3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 15f8153250af..71f7274d2d64 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -97,12 +97,12 @@ def forward( residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if self.layer_idx == 0: + # First layer: concatenate embeds with hidden_states embeds = self.input_layernorm(embeds) hidden_states, residual = self._residual_norm(hidden_states=hidden_states) hidden_states = torch.cat([embeds, hidden_states], dim=-1) else: - # Subsequent layers: only process hidden_states - # and residuals + # Subsequent layers: process hidden_states and residuals only hidden_states, residual = self.input_layernorm(hidden_states, residual) # Self Attention
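
A minimal, self-contained sketch (not part of the patches) of the layer-indexed input sizing this series introduces: layer 0 consumes concat(embeds, hidden_states) at 2 * hidden_size, while every later layer consumes hidden_states alone at hidden_size. The names DraftLayer, proj, and norm below are illustrative torch.nn stand-ins, not vLLM's QKVParallelLinear / RMSNorm wiring.

import torch
import torch.nn as nn

class DraftLayer(nn.Module):
    """Toy stand-in for the multi-layer Eagle3 LlamaDecoderLayer input path."""

    def __init__(self, hidden_size: int, layer_idx: int) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        # Layer 0 sees embeds and hidden_states concatenated (2 * hidden_size);
        # subsequent layers see hidden_states only (hidden_size).
        in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.proj = nn.Linear(in_size, hidden_size, bias=False)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, embeds, hidden_states, residual):
        if self.layer_idx == 0:
            # First layer: normalize both streams and concatenate them.
            residual = hidden_states
            x = torch.cat([self.norm(embeds), self.norm(hidden_states)], dim=-1)
        else:
            # Subsequent layers: fold the residual into hidden_states only;
            # embeds are not used past the first layer.
            hidden_states = hidden_states + residual
            residual = hidden_states
            x = self.norm(hidden_states)
        return self.proj(x), residual

# Usage: the same loop shape as LlamaModel.forward after the patch,
# driving all layers instead of only self.layers[0].
hidden_size, num_layers = 64, 2
layers = nn.ModuleList([DraftLayer(hidden_size, i) for i in range(num_layers)])
embeds = torch.randn(1, 4, hidden_size)
hidden_states = torch.randn(1, 4, hidden_size)
residual = None
for layer in layers:
    hidden_states, residual = layer(embeds, hidden_states, residual)
print(hidden_states.shape)  # torch.Size([1, 4, 64])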