 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
+PADDING_SLOT_ID = -1
+
 
 class EagleProposer:
 
@@ -23,6 +25,7 @@ def __init__(
         self.vllm_config = vllm_config
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
+        self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
@@ -112,22 +115,48 @@ def propose(
             # Update the inputs.
             input_ids = draft_token_ids_list[-1]
             positions += 1
+
+            # NOTE(woosuk): We should handle the case where the draft model
+            # generates tokens beyond the max model length. Since it is complex
+            # to remove such requests from the batch, we keep them in the batch
+            # but adjust the position ids and slot mappings to avoid the
+            # out-of-range access during the model execution. The draft tokens
+            # generated with this adjustment should be ignored.
+            exceeds_max_model_len = positions >= self.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)
+
+            # Increment the sequence lengths.
             attn_metadata.max_seq_len += 1
             attn_metadata.seq_lens += 1
+            # Consider max model length.
+            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                            self.max_model_len)
+            # For the requests that exceed the max model length, we set the
+            # sequence length to 1 to minimize their overheads in attention.
+            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+
             # Compute the slot mapping.
-            block_numbers = positions // self.block_size
+            block_numbers = clamped_positions // self.block_size
             block_ids = block_table.gather(dim=1,
                                            index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          positions % self.block_size)
+                                          clamped_positions % self.block_size)
+            # Mask out the slot mappings that exceed the max model length.
+            # Otherwise, the KV cache will be inadvertently updated with the
+            # padding tokens.
+            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
+                                                    PADDING_SLOT_ID)
 
             # Run the model.
             with set_forward_context(attn_metadata, self.vllm_config):
                 hidden_states = self.model(
                     input_ids=input_ids,
                     hidden_states=hidden_states,
-                    positions=positions,
+                    positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
             draft_token_ids, probs = compute_probs_and_sample_next_token(
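
For context, here is a minimal standalone sketch (not part of this diff) of the clamp-and-mask technique the commit applies: positions at or beyond the max model length are clamped to 0 before RoPE, the corresponding sequence lengths are reduced to 1, and the slot mappings are redirected to `PADDING_SLOT_ID` so padding tokens never overwrite real KV-cache entries. All concrete values (`max_model_len`, `block_size`, the toy `block_table`) are made up for illustration; only `PADDING_SLOT_ID` and the masking pattern come from the diff above.

```python
import torch

PADDING_SLOT_ID = -1  # sentinel slot; the KV-cache write is expected to skip negative slot ids

max_model_len = 8
block_size = 4

# Next-token position per request; the second request has hit the length limit.
positions = torch.tensor([5, 8])
# Toy block table: each row lists the KV-cache block ids owned by a request.
block_table = torch.tensor([[3, 7], [2, 9]])
seq_lens = positions + 1

exceeds_max_model_len = positions >= max_model_len        # [False, True]
# Clamp positions so RoPE never sees an out-of-range index.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# Keep over-length requests in the batch, but make them cheap in attention.
seq_lens = seq_lens.masked_fill(exceeds_max_model_len, 1)

# Slot mapping computed from the clamped positions.
block_numbers = clamped_positions // block_size
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped_positions % block_size
# Redirect over-length requests to the padding slot so their (discarded)
# draft KV entries never overwrite real cache contents.
slot_mapping = slot_mapping.masked_fill(exceeds_max_model_len, PADDING_SLOT_ID)

print(clamped_positions.tolist())  # [5, 0]
print(seq_lens.tolist())           # [6, 1]
print(slot_mapping.tolist())       # [29, -1]
```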