
Commit 4155a1d

rahul-tuli authored and xuebwang-amd committed
Add: Support for multiple hidden layers in Eagle3 (vllm-project#26164)
Signed-off-by: Rahul Tuli <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent 67badb6 commit 4155a1d

2 files changed (+29, -13 lines)


tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
+        pytest.param(
+            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+            id="llama3-eagl3-multiple-layers",
+        ),
     ],
 )
 def test_eagle3_speculators_model(

vllm/model_executor/models/llama_eagle3.py

Lines changed: 25 additions & 13 deletions
@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)

+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
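
The width selection above sets the in_features of the fused QKV projection. A minimal standalone sketch of that logic, using a plain torch.nn.Linear as a stand-in for vLLM's QKVParallelLinear (the toy hidden_size, the 3x output width, and all names below are illustrative, not taken from the commit):

import torch
import torch.nn as nn

hidden_size = 64              # toy size for illustration
num_hidden_layers = 2         # the new multi-layer Eagle3 case

qkv_projs = []
for layer_idx in range(num_hidden_layers):
    # First layer consumes [embeds ; hidden_states] -> 2 * hidden_size;
    # later layers consume hidden_states alone -> hidden_size.
    qkv_input_size = 2 * hidden_size if layer_idx == 0 else hidden_size
    qkv_projs.append(nn.Linear(qkv_input_size, 3 * hidden_size))

first_in = torch.randn(1, 2 * hidden_size)   # concatenated input, layer 0
later_in = torch.randn(1, hidden_size)       # plain hidden states, layer 1+
assert qkv_projs[0](first_in).shape[-1] == 3 * hidden_size
assert qkv_projs[1](later_in).shape[-1] == 3 * hidden_size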
@@ -52,6 +57,7 @@ def __init__(
         )

         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx

         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
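
For intuition, the branch above can be mocked with plain tensors; a rough sketch in which F.normalize stands in for the fused RMSNorm / _residual_norm calls (purely illustrative, not the exact vLLM norm semantics):

import torch
import torch.nn.functional as F

def layer_attention_input(layer_idx, embeds, hidden_states, residual):
    # Toy stand-in for the branch: only layer 0 sees the concatenated
    # [embeds ; hidden_states]; later layers see hidden_states (+ residual).
    if layer_idx == 0:
        embeds = F.normalize(embeds, dim=-1)                 # ~ input_layernorm
        residual = hidden_states
        hidden_states = F.normalize(hidden_states, dim=-1)   # ~ _residual_norm
        return torch.cat([embeds, hidden_states], dim=-1), residual
    hidden_states = hidden_states + residual                 # fused add ...
    return F.normalize(hidden_states, dim=-1), hidden_states # ... + norm

H = 8
embeds, hidden = torch.randn(1, H), torch.randn(1, H)
x0, r0 = layer_attention_input(0, embeds, hidden, None)
x1, r1 = layer_attention_input(1, embeds, hidden, r0)
assert x0.shape[-1] == 2 * H and x1.shape[-1] == H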
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
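
One practical effect of the comprehension above: each draft layer now gets its own prefix, offset by start_layer_id. A hypothetical two-layer draft with start_layer_id=32 would produce suffixes along these lines (toy print; the full prefix string also depends on the surrounding model prefix passed to maybe_prefix):

num_hidden_layers = 2   # hypothetical 2-layer Eagle3 draft
start_layer_id = 32     # hypothetical offset from the target model

for layer_idx in range(num_hidden_layers):
    print(f"layers.{layer_idx + start_layer_id}")
# layers.32
# layers.33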
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]

         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
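
Putting the pieces together, a self-contained toy of the new multi-layer wiring (plain nn.Modules with a single Linear standing in for attention + MLP; class names and shapes are invented for illustration, not from the commit):

import torch
import torch.nn as nn

class ToyEagle3Layer(nn.Module):
    def __init__(self, hidden_size: int, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Only the first layer consumes the concatenated [embeds ; hidden_states].
        in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        self.proj = nn.Linear(in_size, hidden_size)  # stand-in for attn + MLP

    def forward(self, embeds, hidden_states, residual):
        if self.layer_idx == 0:
            x = torch.cat([embeds, hidden_states], dim=-1)
            residual = hidden_states
        else:
            x = hidden_states + residual
            residual = x
        return self.proj(x), residual

hidden_size, num_layers = 16, 3
layers = nn.ModuleList(ToyEagle3Layer(hidden_size, i) for i in range(num_layers))

embeds = torch.randn(1, hidden_size)
hidden_states = torch.randn(1, hidden_size)
residual = None
for layer in layers:
    hidden_states, residual = layer(embeds, hidden_states, residual)
print(hidden_states.shape)  # torch.Size([1, 16])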
