4 changes: 4 additions & 0 deletions tests/speculative_decoding/speculators/test_eagle3.py
@@ -22,6 +22,10 @@
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
+        pytest.param(
+            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
+            id="llama3-eagl3-multiple-layers",
+        ),
Comment on lines +25 to +28
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a mock model, right? I think we can comment that

     ],
 )
 def test_eagle3_speculators_model(
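One way to address the nit above, assuming the checkpoint really is a randomly initialized mock (a sketch, not part of the merged diff), is to note it directly on the new parametrize entry:

        pytest.param(
            # Mock model: randomly initialized 2-layer Eagle3 drafter, used only to
            # exercise the multi-layer code path (not a trained speculator).
            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
            id="llama3-eagl3-multiple-layers",
        ),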
38 changes: 25 additions & 13 deletions vllm/model_executor/models/llama_eagle3.py
@@ -34,15 +34,20 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: Optional[LlamaConfig] = None,
+        layer_idx: int = 0,
     ) -> None:
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
         quant_config = self.get_quant_config(vllm_config)
 
+        # First layer uses 2*hidden_size (embeds + hidden_states concatenated)
+        # Subsequent layers use hidden_size (only hidden_states, no embeds)
+        qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
+
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
-            2 * self.hidden_size,
+            qkv_input_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@
         )
 
         self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layer_idx = layer_idx
 
         if getattr(config, "norm_before_residual", False):
             self._residual_norm = self._norm_before_residual
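To make the width bookkeeping concrete, here is a small self-contained sketch (toy sizes and plain tensors standing in for the real QKVParallelLinear input; only torch is assumed) of why layer 0 needs a 2 * hidden_size projection while later layers keep hidden_size:

    import torch

    hidden_size = 8                              # toy value; the model uses config.hidden_size
    embeds = torch.randn(4, hidden_size)         # normalized token embeddings
    hidden_states = torch.randn(4, hidden_size)  # states handed to the drafter

    for layer_idx in range(2):
        # Mirrors the diff: only the first drafter layer sees the concatenation.
        qkv_input_size = 2 * hidden_size if layer_idx == 0 else hidden_size
        attn_input = (
            torch.cat([embeds, hidden_states], dim=-1) if layer_idx == 0 else hidden_states
        )
        assert attn_input.shape[-1] == qkv_input_size

The assertion holds for any number of drafter layers, since only layer 0 widens its attention input.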
@@ -90,11 +96,15 @@ def forward(
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        embeds = self.input_layernorm(embeds)
-
-        hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
-        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
         # Self Attention
         hidden_states = self.self_attn(
             positions=positions,
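The else branch relies on RMSNorm accepting an optional residual, adding it before normalizing, and returning the pre-norm sum as the new residual. A rough stand-in for that contract (a sketch of the assumed semantics, not vLLM's actual implementation):

    import torch

    def rms_norm_with_residual(x, weight, residual=None, eps=1e-6):
        # Assumed contract: fuse the residual add, normalize, and return the
        # pre-norm sum so the caller can keep threading it to the next layer.
        if residual is not None:
            x = x + residual
        residual = x
        normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
        return normed, residual

    hidden_states, residual = rms_norm_with_residual(
        torch.randn(4, 8), torch.ones(8), residual=torch.randn(4, 8)
    )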
@@ -133,9 +143,11 @@ def __init__(
             [
                 LlamaDecoderLayer(
                     current_vllm_config,
-                    prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
                     config=self.config,
+                    layer_idx=layer_idx,
                 )
+                for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
         if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ def forward(
         assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
-        hidden_states, residual = self.layers[0](
-            positions,
-            input_embeds,
-            hidden_states,
-            residual,
-        )
-
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
         hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
         return hidden_states, hidden_prenorm
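Putting the pieces together, a toy version of the new model-level loop (ToyDrafterLayer is a made-up stand-in for LlamaDecoderLayer) shows how embeds are consumed only by layer 0 while hidden_states and residual are threaded through every layer:

    import torch

    class ToyDrafterLayer:
        def __init__(self, layer_idx: int, hidden_size: int):
            self.layer_idx = layer_idx
            # Layer 0 projects the concatenated (embeds, hidden_states) back down.
            in_size = 2 * hidden_size if layer_idx == 0 else hidden_size
            self.proj = torch.nn.Linear(in_size, hidden_size)

        def __call__(self, positions, embeds, hidden_states, residual):
            if self.layer_idx == 0:
                # Only the first drafter layer looks at the embeddings.
                hidden_states = torch.cat([embeds, hidden_states], dim=-1)
            hidden_states = self.proj(hidden_states)
            residual = hidden_states if residual is None else residual + hidden_states
            return hidden_states, residual

    hidden_size = 8
    layers = [ToyDrafterLayer(i, hidden_size) for i in range(2)]
    positions = torch.arange(4)          # unused by the toy layer, kept for signature parity
    input_embeds = torch.randn(4, hidden_size)
    hidden_states = torch.randn(4, hidden_size)

    residual = None
    for layer in layers:                 # same threading pattern as the new forward above
        hidden_states, residual = layer(positions, input_embeds, hidden_states, residual)
    print(hidden_states.shape, residual.shape)  # torch.Size([4, 8]) torch.Size([4, 8])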
