diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py
index 5ce6e1593b5c..19ba32d8dee4 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/speculative_decoding/speculators/test_eagle3.py
@@ -22,6 +22,10 @@
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
+        pytest.param(
+            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+            id="llama3-eagle3-multiple-layers",
+        ),
     ],
 )
 def test_eagle3_speculators_model(
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 155a4ecea28f..71f7274d2d64 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ def __init__(
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm