Commit 451022d

Add: Support for multiple layers in eagle3
Signed-off-by: Rahul Tuli <[email protected]>
1 parent fc67969 commit 451022d

2 files changed (+50 -40 lines)

tests/speculative_decoding/speculators/test_eagle3.py
Lines changed: 15 additions & 20 deletions

@@ -7,26 +7,21 @@
 from vllm.model_executor.models.interfaces import supports_eagle3
 
 
-@pytest.mark.parametrize(
-    "model_path",
-    [
-        pytest.param(
-            "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
-            id="llama3-eagle3-speculator",
-        ),
-        pytest.param(
-            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
-            id="qwen3-eagle3-speculator",
-        ),
-        pytest.param(
-            "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
-            id="qwen3-eagle3-speculator-w4a16-verifier",
-        ),
-    ],
-)
-def test_eagle3_speculators_model(
-    vllm_runner, example_prompts, model_path, monkeypatch
-):
+@pytest.mark.parametrize("model_path", [
+    pytest.param(
+        "nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
+        id="llama3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
+        id="qwen3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
+        id="qwen3-eagle3-speculator-w4a16-verifier"),
+    pytest.param("nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+                 id="llama3-eagl3-multiple-layers"),
+])
+def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
+                                  monkeypatch):
     """
     Test Eagle3 speculators models properly initialize speculative decoding.
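The new fourth case covers a draft checkpoint with two Eagle3 decoder layers. To run just that case, the test can be selected by its node id; a minimal sketch, assuming a working vLLM test environment with access to the nm-testing checkpoints (the path and parametrize id are taken exactly as spelled in the diff above):

import pytest

# Select only the new multi-layer case via its full node id
# (parametrize id spelled exactly as defined above).
node_id = (
    "tests/speculative_decoding/speculators/test_eagle3.py"
    "::test_eagle3_speculators_model[llama3-eagl3-multiple-layers]"
)
pytest.main([node_id])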

vllm/model_executor/models/llama_eagle3.py
Lines changed: 35 additions & 20 deletions
@@ -28,20 +28,27 @@
 
 
 class LlamaDecoderLayer(LlamaDecoderLayer):
+
     def __init__(
         self,
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = (2 * self.hidden_size
+                          if layer_idx == 0 else self.hidden_size)
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
             self.self_attn.head_dim,
             self.self_attn.total_num_heads,
             self.self_attn.total_num_kv_heads,
@@ -51,6 +58,7 @@ def __init__(
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
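The two comments above capture the sizing rule this commit introduces: only layer 0 consumes the concatenation of token embeddings and target-model hidden states, so only its QKV projection takes a doubled input width. A standalone sketch of that arithmetic (plain Python with illustrative values, not vLLM code):

def qkv_in_features(hidden_size: int, layer_idx: int) -> int:
    # Layer 0 attends over [embeds ; hidden_states] -> 2 * hidden_size.
    # Every later layer attends over hidden_states alone -> hidden_size.
    return 2 * hidden_size if layer_idx == 0 else hidden_size

assert qkv_in_features(4096, 0) == 8192  # e.g. a Llama-3.1-8B-sized hidden dim
assert qkv_in_features(4096, 1) == 4096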
@@ -89,11 +97,18 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
 
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(
+                hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: only process hidden_states
+            # (no embeds, no input_layernorm)
+            hidden_states, residual = self._residual_norm(
+                hidden_states=hidden_states)
 
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
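The branch keeps layer 0 on the original path (normalize the embeddings, residual-norm the hidden states, concatenate) while every later layer gets a reduced path that applies only the residual norm, since those layers never see the embeddings. A toy PyTorch sketch of that routing (hypothetical helper, norms elided for brevity; not the vLLM forward):

import torch

def draft_layer_attn_input(layer_idx: int, embeds: torch.Tensor,
                           hidden_states: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the branch above; the real code also applies
    # input_layernorm / _residual_norm, which are omitted here.
    if layer_idx == 0:
        return torch.cat([embeds, hidden_states], dim=-1)
    return hidden_states

e = torch.randn(2, 8, 4096)  # [batch, seq, hidden]
h = torch.randn(2, 8, 4096)
assert draft_layer_attn_input(0, e, h).shape[-1] == 2 * 4096
assert draft_layer_attn_input(1, e, h).shape[-1] == 4096

Either way, attention returns a hidden_size-wide tensor, which is what makes stacking further draft layers possible.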
@@ -128,15 +143,15 @@ def __init__(
             prefix=maybe_prefix(prefix, "embed_tokens"),
         )
 
-        self.layers = nn.ModuleList(
-            [
-                LlamaDecoderLayer(
-                    current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
-                    config=self.config,
-                )
-            ]
-        )
+        self.layers = nn.ModuleList([
+            LlamaDecoderLayer(current_vllm_config,
+                              prefix=maybe_prefix(
+                                  prefix,
+                                  f"layers.{layer_idx + start_layer_id}"),
+                              config=self.config,
+                              layer_idx=layer_idx)
+            for layer_idx in range(self.config.num_hidden_layers)
+        ])
         if hasattr(self.config, "target_hidden_size"):
             self.fc = torch.nn.Linear(
                 self.config.target_hidden_size * 3, self.config.hidden_size, bias=False
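The list comprehension now builds self.config.num_hidden_layers draft layers instead of exactly one, and offsets each prefix by start_layer_id so the draft layers continue the verifier's layer numbering rather than reusing it (as the pre-existing single-layer code already did). A small illustration of the resulting names (assumed values; start_layer_id is supplied by the caller):

# Illustration only: a 2-layer draft stacked after a 32-layer verifier.
num_hidden_layers = 2  # assumed draft depth
start_layer_id = 32    # assumed verifier depth
prefixes = [f"layers.{layer_idx + start_layer_id}"
            for layer_idx in range(num_hidden_layers)]
assert prefixes == ["layers.32", "layers.33"]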
@@ -165,13 +180,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
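With more than one layer, the model forward now threads hidden_states and residual through every layer rather than invoking self.layers[0] once; input_embeds is passed to each call, but per the layer_idx branch above only layer 0 consumes it. A shape-level toy of that control flow (stand-in modules, not the real LlamaDecoderLayer):

import torch
import torch.nn as nn

class ToyDraftLayer(nn.Module):
    # 2H -> H for layer 0, H -> H for later layers, mirroring the widths above.
    def __init__(self, hidden: int, layer_idx: int):
        super().__init__()
        in_dim = 2 * hidden if layer_idx == 0 else hidden
        self.proj = nn.Linear(in_dim, hidden)
        self.layer_idx = layer_idx

    def forward(self, embeds, hidden_states, residual):
        x = (torch.cat([embeds, hidden_states], dim=-1)
             if self.layer_idx == 0 else hidden_states)
        hidden_states = self.proj(x)
        residual = hidden_states if residual is None else residual + hidden_states
        return hidden_states, residual

hidden = 64
layers = nn.ModuleList([ToyDraftLayer(hidden, i) for i in range(2)])
embeds = torch.randn(1, 4, hidden)
hidden_states = torch.randn(1, 4, hidden)
residual = None
for layer in layers:
    hidden_states, residual = layer(embeds, hidden_states, residual)
assert hidden_states.shape == (1, 4, hidden)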
