From 426441b270002d349c612ec5d8d73079491777d7 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 21:35:11 +0000 Subject: [PATCH 01/15] padded, sync-free speculative decoding Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 127 +++++++++++++----- vllm/v1/worker/gpu_input_batch.py | 5 +- vllm/v1/worker/gpu_model_runner.py | 203 ++++++++++++++++++++--------- 3 files changed, 238 insertions(+), 97 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d8390..e67255a17775 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -93,14 +93,13 @@ def __init__( dtype=self.dtype, device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_batch_size + 1, - device=device, - dtype=torch.int32, - ) + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), @@ -155,13 +154,16 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -225,6 +227,12 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) + return draft_token_ids.view(-1, 1) + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] @@ -242,15 +250,12 @@ def propose( draft_token_ids = logits.argmax(dim=-1) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - assert isinstance(attn_metadata, self.allowed_attn_types) + if not isinstance(attn_metadata, self.allowed_attn_types): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding with num_speculative_tokens > 1: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}") # Generate the remaining draft tokens. 
draft_token_ids_list = [draft_token_ids] @@ -260,10 +265,13 @@ def propose( input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -283,27 +291,35 @@ def propose( positions) # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + common_attn_metadata.seq_lens += 1 + common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, + 1) + + common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. - block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( + block_numbers = positions // self.block_size + block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + common_attn_metadata.slot_mapping = (block_ids * self.block_size + + positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
- attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID) + + # Rebuild attention metadata + attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -322,12 +338,17 @@ def propose( with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], hidden_states=self.hidden_states[:input_batch_size], inputs_embeds=inputs_embeds, ) + if self.method in ("deepseek_mtp", "ernie_mtp"): + last_hidden_states = ret_hidden_states + hidden_states = last_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) @@ -338,6 +359,41 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_inputs_deferred(self, + common_attn_metadata: CommonAttentionMetadata): + """ + This function is used to prepare the inputs for the spec decode. + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `last_token_indices`. + """ + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + return spec_common_attn_metadata, token_indices + def propose_tree( self, batch_size: int, @@ -537,6 +593,9 @@ def prepare_inputs( device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + # num_rejected_tokens = num_rejected_tokens * 0 + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - num_rejected_tokens diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index bf9b16575e60..8329950dd09d 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -66,7 +66,10 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - return self.output_token_ids[idx - self.num_prompt_tokens] + elif idx - self.num_prompt_tokens < len(self.output_token_ids): + return self.output_token_ids[idx - self.num_prompt_tokens] + else: + return -1 class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 897c3a621320..3741e21345c9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -310,6 +310,10 @@ def __init__( self.hidden_size, dtype=self.dtype, numpy=False) + self.backup_next_token_ids = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -336,6 +340,8 @@ def __init__( self.max_num_tokens), dtype=np.int64) + self.num_discarded_requests = 0 + # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values @@ -872,6 +878,28 @@ def _prepare_inputs( seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( + discard_request_indices) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + + # Precompute get_token_id for when there is no valid next token + self.backup_next_token_ids.np[:num_reqs] = np.array([ + self.requests[self.input_batch.req_ids[i]].get_token_id( + self.seq_lens.np[i]) for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + # Copy the tensors to the GPU. 
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) @@ -1729,23 +1757,12 @@ def _bookkeeping_sync( if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. - # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests].tolist() + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. @@ -1956,6 +1973,20 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) + if self.speculative_config: + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampler_output.sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + with record_function_or_nullcontext("Bookkeep"): assert isinstance(hidden_states, torch.Tensor) ( @@ -1970,19 +2001,19 @@ def execute_model( logits, hidden_states, num_scheduled_tokens) - if self.speculative_config: - assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - self.input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) + # if self.speculative_config: + # assert spec_decode_common_attn_metadata is not None + # with record_function_or_nullcontext("Draft"): + # self._draft_token_ids = self.propose_draft_token_ids( + # scheduler_output, + # valid_sampled_token_ids, + # self.input_batch.sampling_metadata, + # hidden_states, + # sample_hidden_states, + # aux_hidden_states, + # spec_decode_metadata, + # spec_decode_common_attn_metadata, + # ) with record_function_or_nullcontext("EPLB"): self.eplb_step() @@ -2022,7 +2053,7 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], + sampled_token_ids: torch.Tensor | list[list[int]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, @@ -2032,11 +2063,14 @@ def propose_draft_token_ids( ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = 
scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states @@ -2056,26 +2090,52 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): + assert isinstance(sampled_token_ids, torch.Tensor) assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - req_ids = self.input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + discard_sampled_tokens_req_indices = \ + self.discard_request_indices\ + .gpu[:self.num_discarded_requests] + + _max_gen_len = sampled_token_ids.shape[-1] + # Get all sampled tokens from valid requests + _valid_sampled_token_ids_gpu = sampled_token_ids.clone() + _valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + # _valid_sampled_token_ids_gpu[ + # discard_sampled_tokens_req_indices, :] = -1 + + # Generate a mask for all valid tokens within those requests + if _max_gen_len == 1: + _valid_mask = torch.ones_like(_valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + _valid_mask = ((_valid_sampled_token_ids_gpu != -1) & + (_valid_sampled_token_ids_gpu + < self.input_batch.vocab_size)) + + # Count valid tokens in each request + _valid_sampled_count = _valid_mask.sum(dim=1) + + _batch = _valid_sampled_token_ids_gpu.shape[0] + + # Get the rightmost valid index per row + _last_valid_indices = _valid_sampled_count - 1 + + _last_valid_indices_safe = torch.max( + _last_valid_indices, torch.zeros_like(_last_valid_indices)) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + _selected_tokens = torch.gather( + _valid_sampled_token_ids_gpu, 1, + _last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + next_token_ids_gpu_2 = torch.where( + _last_valid_indices != -1, _selected_tokens, + self.backup_next_token_ids.gpu[:_batch]) + + token_indices_to_sample = None if spec_decode_metadata is None: # input_ids can be None for multimodal models. @@ -2089,17 +2149,35 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - # TODO(woosuk): Refactor this. 
- num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + if True: # TODO + _num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + _num_rejected_tokens_gpu = torch.where( + _num_draft_tokens_gpu > 0, + _num_draft_tokens_gpu + 1 - _valid_sampled_count, + torch.zeros_like(_num_draft_tokens_gpu)) + + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs_deferred(common_attn_metadata) + token_indices_to_sample = \ + common_attn_metadata.query_start_loc[1:] - 1 \ + - _num_rejected_tokens_gpu + else: + # TODO(woosuk): Refactor this. + num_draft_tokens = spec_decode_metadata.num_draft_tokens + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, num_rejected_tokens_cpu) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. @@ -2118,7 +2196,8 @@ def propose_draft_token_ids( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, + next_token_ids=next_token_ids_gpu_2, + last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, From 5c9cd1c9b5bc01e3e05266cc22aee1b6e2776471 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 21:48:43 +0000 Subject: [PATCH 02/15] remove dead code Signed-off-by: Benjamin Chislett --- vllm/v1/worker/gpu_model_runner.py | 45 +++++++++++------------------- 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3741e21345c9..35f08cd8941a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2149,35 +2149,22 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - if True: # TODO - _num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) - - _num_rejected_tokens_gpu = torch.where( - _num_draft_tokens_gpu > 0, - _num_draft_tokens_gpu + 1 - _valid_sampled_count, - torch.zeros_like(_num_draft_tokens_gpu)) - - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs_deferred(common_attn_metadata) - token_indices_to_sample = \ - common_attn_metadata.query_start_loc[1:] - 1 \ - - _num_rejected_tokens_gpu - else: - # TODO(woosuk): Refactor this. 
- num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + _num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + _num_rejected_tokens_gpu = torch.where( + _num_draft_tokens_gpu > 0, + _num_draft_tokens_gpu + 1 - _valid_sampled_count, + torch.zeros_like(_num_draft_tokens_gpu)) + + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs_deferred(common_attn_metadata) + token_indices_to_sample = \ + common_attn_metadata.query_start_loc[1:] - 1 \ + - _num_rejected_tokens_gpu target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. From 94cf9c5cb418867f9151d2f8318ff49fb0349d04 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 21:51:22 +0000 Subject: [PATCH 03/15] tiny refactor Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 3 --- vllm/v1/worker/gpu_model_runner.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e67255a17775..e2ad9e3ed56e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -593,9 +593,6 @@ def prepare_inputs( device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - - # num_rejected_tokens = num_rejected_tokens * 0 - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - num_rejected_tokens diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 35f08cd8941a..9aabb1c9cb55 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -314,6 +314,7 @@ def __init__( dtype=torch.int32) self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + self.num_discarded_requests = 0 # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -340,8 +341,6 @@ def __init__( self.max_num_tokens), dtype=np.int64) - self.num_discarded_requests = 0 - # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values From 940cb1ab6f7a2a2d42af3bc8ac7dcee2aaa77337 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 21:52:27 +0000 Subject: [PATCH 04/15] remove more dead code Signed-off-by: Benjamin Chislett --- vllm/v1/worker/gpu_model_runner.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9aabb1c9cb55..dd1c7b4c189a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2000,20 +2000,6 @@ def execute_model( logits, hidden_states, num_scheduled_tokens) - # if self.speculative_config: - # assert spec_decode_common_attn_metadata is not None - # with record_function_or_nullcontext("Draft"): - # self._draft_token_ids = self.propose_draft_token_ids( - # scheduler_output, - # valid_sampled_token_ids, - # self.input_batch.sampling_metadata, - # hidden_states, - # sample_hidden_states, - # aux_hidden_states, - # spec_decode_metadata, - # spec_decode_common_attn_metadata, - # ) - with record_function_or_nullcontext("EPLB"): self.eplb_step() From ff408d995cd72a22f5da7588483419df2ed644b1 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 21:55:53 +0000 Subject: [PATCH 05/15] add back support for ngram Signed-off-by: Benjamin Chislett --- vllm/v1/worker/gpu_model_runner.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dd1c7b4c189a..a677cf6be6fc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1972,7 +1972,9 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - if self.speculative_config: + if self.speculative_config and self.speculative_config.use_eagle(): + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("Draft"): self._draft_token_ids = self.propose_draft_token_ids( @@ -2000,6 +2002,22 @@ def execute_model( logits, hidden_states, num_scheduled_tokens) + if self.speculative_config and not self.speculative_config.use_eagle(): + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. 
+ assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + with record_function_or_nullcontext("EPLB"): self.eplb_step() From 6f80fad6326777ab23a59d77771a1441ec9d4665 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 22:04:33 +0000 Subject: [PATCH 06/15] clean up variable names Signed-off-by: Benjamin Chislett --- vllm/v1/worker/gpu_model_runner.py | 64 +++++++++++++++--------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a677cf6be6fc..69fea0b47671 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2095,48 +2095,48 @@ def propose_draft_token_ids( elif self.speculative_config.use_eagle(): assert isinstance(sampled_token_ids, torch.Tensor) assert isinstance(self.drafter, EagleProposer) + + # TODO(Ben): Combine this bookkeeping into a custom fused kernel + + # Mask out the sampled tokens indices that should not be sampled. discard_sampled_tokens_req_indices = \ self.discard_request_indices\ .gpu[:self.num_discarded_requests] - _max_gen_len = sampled_token_ids.shape[-1] - # Get all sampled tokens from valid requests - _valid_sampled_token_ids_gpu = sampled_token_ids.clone() - _valid_sampled_token_ids_gpu.index_fill_( + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( 0, discard_sampled_tokens_req_indices, -1) - # _valid_sampled_token_ids_gpu[ - # discard_sampled_tokens_req_indices, :] = -1 # Generate a mask for all valid tokens within those requests - if _max_gen_len == 1: - _valid_mask = torch.ones_like(_valid_sampled_token_ids_gpu, - dtype=torch.bool) + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) else: - _valid_mask = ((_valid_sampled_token_ids_gpu != -1) & - (_valid_sampled_token_ids_gpu - < self.input_batch.vocab_size)) - - # Count valid tokens in each request - _valid_sampled_count = _valid_mask.sum(dim=1) + valid_mask = ((valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu + < self.input_batch.vocab_size)) - _batch = _valid_sampled_token_ids_gpu.shape[0] + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) # Get the rightmost valid index per row - _last_valid_indices = _valid_sampled_count - 1 + last_valid_indices = valid_sampled_tokens_count - 1 - _last_valid_indices_safe = torch.max( - _last_valid_indices, torch.zeros_like(_last_valid_indices)) + last_valid_indices_safe = torch.max( + last_valid_indices, torch.zeros_like(last_valid_indices)) # Get last valid token from each row # (assume undefined state where there is no valid token) - _selected_tokens = torch.gather( - _valid_sampled_token_ids_gpu, 1, - _last_valid_indices_safe.unsqueeze(1)).squeeze(1) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) # Use last token if valid, pre-computed backup if not - next_token_ids_gpu_2 = torch.where( - _last_valid_indices != -1, _selected_tokens, - self.backup_next_token_ids.gpu[:_batch]) + batch_size = valid_sampled_token_ids_gpu.shape[0] + 
next_token_ids = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids.gpu[:batch_size]) token_indices_to_sample = None @@ -2152,22 +2152,22 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - _num_draft_tokens_gpu = torch.cat([ + num_draft_tokens_gpu = torch.cat([ spec_decode_metadata.cu_num_draft_tokens[0:1], spec_decode_metadata.cu_num_draft_tokens[1:] - spec_decode_metadata.cu_num_draft_tokens[:-1] ]) - _num_rejected_tokens_gpu = torch.where( - _num_draft_tokens_gpu > 0, - _num_draft_tokens_gpu + 1 - _valid_sampled_count, - torch.zeros_like(_num_draft_tokens_gpu)) + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) common_attn_metadata, token_indices =\ self.drafter.prepare_inputs_deferred(common_attn_metadata) token_indices_to_sample = \ common_attn_metadata.query_start_loc[1:] - 1 \ - - _num_rejected_tokens_gpu + - num_rejected_tokens_gpu target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. @@ -2186,7 +2186,7 @@ def propose_draft_token_ids( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids_gpu_2, + next_token_ids=next_token_ids, last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, From 7b74a6956e965950b414c5fc801f14b0f06b8ee9 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 22:39:06 +0000 Subject: [PATCH 07/15] fix pre-commit Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e2ad9e3ed56e..6dc2e4254b59 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -154,7 +154,7 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: torch.Tensor | None, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 69fea0b47671..21585ae72843 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2056,7 +2056,7 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: torch.Tensor | list[list[int]], + sampled_token_ids: Union[torch.Tensor, list[list[int]]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, From 3a8985382c6513d955fb2393e7fca5d630ebd495 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 9 Sep 2025 22:42:34 +0000 Subject: [PATCH 08/15] use clamped positions for safety Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6dc2e4254b59..50c051e4184b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -302,12 +302,13 @@ def propose( common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. 
- block_numbers = positions // self.block_size + block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - common_attn_metadata.slot_mapping = (block_ids * self.block_size + - positions % self.block_size) + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. From fe748a6692ef2904d1eb5b6a20384029fb8966d4 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 15 Sep 2025 14:18:58 +0000 Subject: [PATCH 09/15] refactor and feature-flag padded batch Signed-off-by: Benjamin Chislett --- vllm/config/__init__.py | 5 ++ vllm/v1/spec_decode/eagle.py | 113 +++++++++++++++++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 93 +++++++----------------- 3 files changed, 138 insertions(+), 73 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 4f4673ac6e67..b17ca9309250 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2000,6 +2000,11 @@ class SpeculativeConfig: disable_by_batch_size: Optional[int] = None """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" + disable_padded_batch: bool = False + """Disable input padding for speculative decoding. If set to True, + speculative input batches can contain sequences of different lengths, + which may only be supported by certain attention backends. This currently + only affects the EAGLE method of speculation.""" # Ngram proposer configuration prompt_lookup_max: Optional[int] = None diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 50c051e4184b..4be502db9858 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,6 +27,9 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch logger = init_logger(__name__) @@ -106,6 +109,13 @@ def __init__( dtype=self.dtype, device=device) + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True) + # Determine allowed attention backends once during initialization. self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] if current_platform.is_rocm(): @@ -360,14 +370,104 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_next_token_ids(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int) -> \ + tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. 
+ No blocking CPU operations should be introduced in this function. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = \ + discard_request_indices[:num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + valid_mask = ( + (valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids.gpu[:batch_size]) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_num_rejected_tokens(self, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) \ + -> torch.Tensor: + """ + Calculate the number of rejected tokens for each request based on the + number of valid sampled tokens and the number of draft tokens in each. + """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) + + return num_rejected_tokens_gpu + def prepare_inputs_deferred(self, - common_attn_metadata: CommonAttentionMetadata): + common_attn_metadata: CommonAttentionMetadata, + num_rejected_tokens: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, but does not consider the rejected tokens. Instead, all tokens are included as inputs to the speculator, with the rejected tokens - used as padding and filtered out later by `last_token_indices`. + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. 
""" query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu @@ -393,7 +493,10 @@ def prepare_inputs_deferred(self, causal=True, ) - return spec_common_attn_metadata, token_indices + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ + - num_rejected_tokens + + return spec_common_attn_metadata, token_indices, token_indices_to_sample def propose_tree( self, @@ -571,7 +674,7 @@ def prepare_inputs( num_rejected_tokens: torch.Tensor ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 21585ae72843..223a07fef6ed 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -310,8 +310,6 @@ def __init__( self.hidden_size, dtype=self.dtype, numpy=False) - self.backup_next_token_ids = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) self.discard_request_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int64) self.num_discarded_requests = 0 @@ -892,13 +890,6 @@ def _prepare_inputs( self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) - # Precompute get_token_id for when there is no valid next token - self.backup_next_token_ids.np[:num_reqs] = np.array([ - self.requests[self.input_batch.req_ids[i]].get_token_id( - self.seq_lens.np[i]) for i in range(num_reqs) - ]) - self.backup_next_token_ids.copy_to_gpu(num_reqs) - # Copy the tensors to the GPU. self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) @@ -2096,51 +2087,19 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, torch.Tensor) assert isinstance(self.drafter, EagleProposer) - # TODO(Ben): Combine this bookkeeping into a custom fused kernel - - # Mask out the sampled tokens indices that should not be sampled. 
- discard_sampled_tokens_req_indices = \ - self.discard_request_indices\ - .gpu[:self.num_discarded_requests] - - valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) - - # Generate a mask for all valid tokens within those requests - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) - else: - valid_mask = ((valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu - < self.input_batch.vocab_size)) - - # Count the number of valid tokens in each request - valid_sampled_tokens_count = valid_mask.sum(dim=1) - - # Get the rightmost valid index per row - last_valid_indices = valid_sampled_tokens_count - 1 - - last_valid_indices_safe = torch.max( - last_valid_indices, torch.zeros_like(last_valid_indices)) - - # Get last valid token from each row - # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) - - # Use last token if valid, pre-computed backup if not - batch_size = valid_sampled_token_ids_gpu.shape[0] - next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) - - token_indices_to_sample = None + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids( + common_attn_metadata, + spec_decode_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) if spec_decode_metadata is None: + token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. @@ -2152,22 +2111,20 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) - - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) - - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs_deferred(common_attn_metadata) - token_indices_to_sample = \ - common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens_gpu + num_rejected_tokens = self.drafter.prepare_num_rejected_tokens( + spec_decode_metadata, valid_sampled_tokens_count) + + if self.speculative_config.disable_padded_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, + num_rejected_tokens.to('cpu').int()) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_deferred( + common_attn_metadata, num_rejected_tokens) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. 
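The GPU-side bookkeeping introduced above (prepare_next_token_ids) selects, for each request, the rightmost valid sampled token and falls back to a precomputed backup token id when every slot in that row was discarded or rejected, so no GPU-to-CPU copy of the sampled ids is required. Below is a minimal standalone sketch of that selection, with illustrative tensor names (not the runner's actual buffers) and the max_gen_len == 1 special case omitted:

    import torch

    def select_next_tokens(sampled_token_ids: torch.Tensor,
                           discard_indices: torch.Tensor,
                           backup_next_token_ids: torch.Tensor,
                           vocab_size: int) -> torch.Tensor:
        # sampled_token_ids: [batch, max_gen_len]; rejected slots hold -1.
        valid = sampled_token_ids.clone()
        # Requests that must not be sampled (e.g. partial prefills) are masked.
        valid.index_fill_(0, discard_indices, -1)
        valid_mask = (valid != -1) & (valid < vocab_size)
        # Number of accepted tokens (plus the bonus token) per request.
        valid_count = valid_mask.sum(dim=1)
        last_idx = valid_count - 1  # -1 when a row has no valid token
        gathered = valid.gather(1, last_idx.clamp(min=0).unsqueeze(1)).squeeze(1)
        # Use the last valid token if there is one, else the backup id.
        return torch.where(last_idx != -1, gathered,
                           backup_next_token_ids[:valid.shape[0]])

    # Example: the second request is fully rejected, so it uses its backup id.
    ids = torch.tensor([[5, 7, -1], [-1, -1, -1]])
    no_discards = torch.empty(0, dtype=torch.long)
    print(select_next_tokens(ids, no_discards, torch.tensor([11, 13]), 32000))
    # -> tensor([ 7, 13])

Everything stays on the device, which is what allows the draft proposal to launch before the CPU-side bookkeeping of sampled tokens completes.
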
From 8a5bdac2962091c8a0d5054d1c558224a6bf879c Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 15 Sep 2025 14:50:46 +0000 Subject: [PATCH 10/15] add back cpu-side eagle input prep after perf degradation Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 82 ++++++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 44 +++++++++------- 2 files changed, 82 insertions(+), 44 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1cf25f48bc42..07f87fcd9238 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -371,9 +371,40 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def prepare_next_token_ids(self, + def prepare_next_token_ids_host( + self, sampled_token_ids: list[list[int]], + requests: dict[str, + CachedRequestState], gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = (req_state.num_computed_tokens + + num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_device(self, common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -385,7 +416,8 @@ def prepare_next_token_ids(self, It calculates the next token ids and the number of valid sampled tokens for each request, considering the "discarded" requests whose next token is not sampled and comes from `request.get_token_id()` instead. - No blocking CPU operations should be introduced in this function. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. """ # TODO(Ben): Combine this into a custom fused kernel @@ -437,13 +469,18 @@ def prepare_next_token_ids(self, return next_token_ids, valid_sampled_tokens_count - def prepare_num_rejected_tokens(self, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) \ - -> torch.Tensor: + def prepare_inputs_deferred(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ - Calculate the number of rejected tokens for each request based on the - number of valid sampled tokens and the number of draft tokens in each. + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. 
Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. """ num_draft_tokens_gpu = torch.cat([ spec_decode_metadata.cu_num_draft_tokens[0:1], @@ -456,20 +493,6 @@ def prepare_num_rejected_tokens(self, num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, torch.zeros_like(num_draft_tokens_gpu)) - return num_rejected_tokens_gpu - - def prepare_inputs_deferred(self, - common_attn_metadata: CommonAttentionMetadata, - num_rejected_tokens: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding - It updates the common_attn_metadata for speculative decoding, - but does not consider the rejected tokens. Instead, all tokens - are included as inputs to the speculator, with the rejected tokens - used as padding and filtered out later by `token_indices_to_sample`. - No blocking CPU operations should be introduced in this function. - """ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_query_len_per_req = (query_start_loc_cpu[1:] - @@ -495,7 +518,7 @@ def prepare_inputs_deferred(self, ) token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens + - num_rejected_tokens_gpu return spec_common_attn_metadata, token_indices, token_indices_to_sample @@ -671,8 +694,8 @@ def propose_tree( def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, - # [batch_size] - num_rejected_tokens: torch.Tensor + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. @@ -696,6 +719,13 @@ def prepare_inputs( # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7dfc370b82fe..c4170fbe6d3d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2141,7 +2141,10 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - if self.speculative_config and self.speculative_config.use_eagle(): + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.use_eagle() and \ + not self.speculative_config.disable_padded_batch + if use_padded_batch_for_eagle: # EAGLE speculative decoding can use the GPU sampled tokens # as inputs, and does not need to wait for bookkeeping to finish. assert spec_decode_common_attn_metadata is not None @@ -2170,7 +2173,7 @@ def execute_model( logits, hidden_states, num_scheduled_tokens) - if self.speculative_config and not self.speculative_config.use_eagle(): + if self.speculative_config and not use_padded_batch_for_eagle: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. 
assert spec_decode_common_attn_metadata is not None @@ -2261,19 +2264,24 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle(): - assert isinstance(sampled_token_ids, torch.Tensor) assert isinstance(self.drafter, EagleProposer) - next_token_ids, valid_sampled_tokens_count = \ - self.drafter.prepare_next_token_ids( - common_attn_metadata, - spec_decode_metadata, - sampled_token_ids, - self.requests, - self.input_batch, - self.discard_request_indices.gpu, - self.num_discarded_requests - ) + if self.speculative_config.disable_padded_batch: + assert isinstance(sampled_token_ids, list) + next_token_ids = self.drafter.prepare_next_token_ids_host( + sampled_token_ids, self.requests, self.input_batch, + scheduler_output.num_scheduled_tokens) + else: + assert isinstance(sampled_token_ids, torch.Tensor) + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids_device( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2288,20 +2296,20 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - num_rejected_tokens = self.drafter.prepare_num_rejected_tokens( - spec_decode_metadata, valid_sampled_tokens_count) - if self.speculative_config.disable_padded_batch: token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( common_attn_metadata, - num_rejected_tokens.to('cpu').int()) + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) else: common_attn_metadata, token_indices, \ token_indices_to_sample =\ self.drafter.prepare_inputs_deferred( - common_attn_metadata, num_rejected_tokens) + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. 
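The padded ("deferred") input preparation above keeps every scheduled token as drafter input and only filters rejected tokens out at sampling time, so query_start_loc, the token layout, and the slot mapping keep their original shapes. A small sketch of how token_indices_to_sample falls out of query_start_loc and the per-request rejected-token counts (function and variable names are illustrative; the patch itself reads the total token count from the CPU copy, query_start_loc_cpu):

    import torch

    def padded_sample_indices(query_start_loc: torch.Tensor,
                              num_rejected_tokens: torch.Tensor):
        # All scheduled tokens feed the drafter; rejected ones act as padding.
        total_num_tokens = int(query_start_loc[-1])
        token_indices = torch.arange(total_num_tokens,
                                     device=query_start_loc.device)
        # Sample the drafter at the last accepted position of each request.
        token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected_tokens
        return token_indices, token_indices_to_sample

    # Mirrors the unit test added later in the series: three requests with
    # query_len 3 and [1, 0, 2] rejected tokens sample at positions [1, 5, 6].
    qsl = torch.tensor([0, 3, 6, 9])
    _, to_sample = padded_sample_indices(qsl, torch.tensor([1, 0, 2]))
    print(to_sample)  # tensor([1, 5, 6])

Because nothing is dropped, the CPU-side rejected-token counts used by the non-padded prepare_inputs path (and the synchronization they imply) are not needed; the trade-off is that rejected positions are still run through the drafter as padding.
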
From a1e8f73a459140c1fdcb4c8e7c178b7bdac6db3d Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 15 Sep 2025 18:57:42 +0000 Subject: [PATCH 11/15] tweaks for clarity Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 19 +++++++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 07f87fcd9238..da2c423574b6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -371,7 +371,7 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def prepare_next_token_ids_host( + def prepare_next_token_ids_cpu( self, sampled_token_ids: list[list[int]], requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -403,7 +403,7 @@ def prepare_next_token_ids_host( device=self.input_ids.device) return next_token_ids - def prepare_next_token_ids_device(self, + def prepare_next_token_ids_gpu(self, common_attn_metadata: CommonAttentionMetadata, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c4170fbe6d3d..a027578871da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2267,14 +2267,25 @@ def propose_draft_token_ids( assert isinstance(self.drafter, EagleProposer) if self.speculative_config.disable_padded_batch: - assert isinstance(sampled_token_ids, list) - next_token_ids = self.drafter.prepare_next_token_ids_host( + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." + next_token_ids = self.drafter.prepare_next_token_ids_cpu( sampled_token_ids, self.requests, self.input_batch, scheduler_output.num_scheduled_tokens) else: - assert isinstance(sampled_token_ids, torch.Tensor) + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." next_token_ids, valid_sampled_tokens_count = \ - self.drafter.prepare_next_token_ids_device( + self.drafter.prepare_next_token_ids_gpu( common_attn_metadata, sampled_token_ids, self.requests, From ef59c6c796e514386ad33cabbad2bbb361a382b4 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Mon, 15 Sep 2025 20:46:51 +0000 Subject: [PATCH 12/15] unit test for prepare_input_deferred. 
more tests coming soon Signed-off-by: Benjamin Chislett --- tests/v1/spec_decode/test_eagle.py | 96 ++++++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 5 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index ddedc61aae29..0492dcbed31f 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -19,6 +19,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -90,10 +91,24 @@ def test_prepare_inputs(): device=device, ) - # Rejected tokens per request: [1, 3, 2] - num_rejected_tokens = torch.tensor([1, 3, 2], - dtype=torch.int32, - device=device) + # If there are `k` sampled tokens, then `k-1` tokens are draft tokens + # from the previous iteration, and the last token is the bonus token sampled + # from the base model. + num_draft_tokens = [3, 6, 4] # one less than query_lens + # num rejected tokens is [1, 3, 2] + ACCEPT_TOKEN = 0 + BONUS_TOKEN = 1 + REJECT_TOKEN = -1 + sampled_token_ids = [ + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + [ + ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, + REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN + ], + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN] + ] + sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN] + for seq in sampled_token_ids] # Expected calculations: # query_len_per_req = [4, 7, 5] @@ -125,7 +140,7 @@ def test_prepare_inputs(): proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, num_rejected_tokens.cpu()) + common_attn_metadata, sampled_token_ids, num_draft_tokens) assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) @@ -133,6 +148,77 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) +def test_prepare_inputs_deferred(): + """ + Input scenario is 3 requests with num_speculative_tokens == 2 and: + - Request 1: query_len = 3, rejected = 1 + - Request 2: query_len = 3, rejected = 0 + - Request 3: query_len = 3, rejected = 2 + + Expected outputs: + token_indices: [0, 1, 2, + 3, 4, 5, + 6, 7, 8] + Reason: Deferred computation should not disturb the original indices. + + token_indices_to_sample: [1, 5, 6] + Reason: After accounting for rejections, these are the valid token positions + from the original indices to sample from. 
+ """ + + device = torch.device(current_platform.device_type) + + expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.int32, + device=device) + expected_token_indices_to_sample = torch.tensor([1, 5, 6], + dtype=torch.int32, + device=device) + + num_speculative_tokens = 2 + batch_spec = BatchSpec( + seq_lens=[3, 3, 3], + query_lens=[3, 3, 3], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] + expected_query_start_loc = torch.tensor([0, 3, 6, 9], + dtype=torch.int32, + device=device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids=[[0] * num_speculative_tokens] * 3, + device=device, + ) + + # num_rejected_tokens = [1, 0, 2] + # num_draft_tokens = [2, 2, 2] + # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens + valid_sampled_tokens_count = torch.tensor([2, 3, 1], + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + output_metadata, token_indices, token_indices_to_sample = \ + proposer.prepare_inputs_deferred( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + + assert output_metadata.max_query_len == 3 + assert torch.equal(output_metadata.query_start_loc, + expected_query_start_loc) + assert torch.equal(token_indices, expected_token_indices) + assert torch.equal(token_indices_to_sample, + expected_token_indices_to_sample) + + @pytest.mark.parametrize("method", ["eagle", "eagle3"]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) From 71afe35639e738e0a23cc089c95cea70460aa063 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 16 Sep 2025 15:40:40 +0000 Subject: [PATCH 13/15] tweaks Signed-off-by: Benjamin Chislett --- tests/v1/spec_decode/test_eagle.py | 4 +-- vllm/config/__init__.py | 2 +- vllm/v1/spec_decode/eagle.py | 5 ++-- vllm/v1/worker/gpu_model_runner.py | 44 ++++++++++++------------------ 4 files changed, 24 insertions(+), 31 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 0492dcbed31f..385e98a08405 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -148,7 +148,7 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) -def test_prepare_inputs_deferred(): +def test_prepare_inputs_padded(): """ Input scenario is 3 requests with num_speculative_tokens == 2 and: - Request 1: query_len = 3, rejected = 1 @@ -206,7 +206,7 @@ def test_prepare_inputs_deferred(): proposer = _create_proposer("eagle", num_speculative_tokens) output_metadata, token_indices, token_indices_to_sample = \ - proposer.prepare_inputs_deferred( + proposer.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 2683255597e2..5ccfbec5ea7e 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1949,7 +1949,7 @@ class SpeculativeConfig: disable_by_batch_size: Optional[int] = None """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" - disable_padded_batch: bool = False + disable_padded_drafter_batch: bool = False """Disable input padding for speculative decoding. 
If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index da2c423574b6..1c71d4b97ddb 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -403,7 +403,7 @@ def prepare_next_token_ids_cpu( device=self.input_ids.device) return next_token_ids - def prepare_next_token_ids_gpu(self, + def prepare_next_token_ids_padded(self, common_attn_metadata: CommonAttentionMetadata, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], @@ -416,6 +416,7 @@ def prepare_next_token_ids_gpu(self, It calculates the next token ids and the number of valid sampled tokens for each request, considering the "discarded" requests whose next token is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. This function must use device functions to operate on the inputs, and should not introduce any blocking CPU-GPU synchronization. """ @@ -469,7 +470,7 @@ def prepare_next_token_ids_gpu(self, return next_token_ids, valid_sampled_tokens_count - def prepare_inputs_deferred(self, + def prepare_inputs_padded(self, common_attn_metadata: CommonAttentionMetadata, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor) -> \ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a027578871da..761663d2704b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1910,7 +1910,7 @@ def _bookkeeping_sync( num_nans_in_logits = self._get_nans_in_logits(logits) discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests].tolist() + self.discard_request_indices.np[:self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -1951,10 +1951,10 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] - invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) assert sampled_token_ids.shape[-1] == 1 @@ -2141,17 +2141,12 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - use_padded_batch_for_eagle = self.speculative_config and \ - self.speculative_config.use_eagle() and \ - not self.speculative_config.disable_padded_batch - if use_padded_batch_for_eagle: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. 
+ def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("Draft"): self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, - sampler_output.sampled_token_ids, + sampled_token_ids, self.input_batch.sampling_metadata, hidden_states, sample_hidden_states, @@ -2160,6 +2155,14 @@ def execute_model( spec_decode_common_attn_metadata, ) + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.use_eagle() and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch_for_eagle: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampler_output.sampled_token_ids) + with record_function_or_nullcontext("Bookkeep"): ( num_nans_in_logits, @@ -2176,18 +2179,7 @@ def execute_model( if self.speculative_config and not use_padded_batch_for_eagle: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. - assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - self.input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) + propose_draft_token_ids(valid_sampled_token_ids) with record_function_or_nullcontext("EPLB"): self.eplb_step() @@ -2266,7 +2258,7 @@ def propose_draft_token_ids( elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - if self.speculative_config.disable_padded_batch: + if self.speculative_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. @@ -2285,7 +2277,7 @@ def propose_draft_token_ids( "sampled_token_ids should be a torch.Tensor when" \ "padded-batch is enabled." 
next_token_ids, valid_sampled_tokens_count = \ - self.drafter.prepare_next_token_ids_gpu( + self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, @@ -2307,7 +2299,7 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - if self.speculative_config.disable_padded_batch: + if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( @@ -2317,7 +2309,7 @@ def propose_draft_token_ids( else: common_attn_metadata, token_indices, \ token_indices_to_sample =\ - self.drafter.prepare_inputs_deferred( + self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count) From 52d6a48cc9cec2d410a00f57ed326bae98d6df54 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 16 Sep 2025 16:13:57 +0000 Subject: [PATCH 14/15] tests for new eagle prepare_next_token_ids functions Signed-off-by: Benjamin Chislett --- tests/v1/spec_decode/test_eagle.py | 83 ++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 385e98a08405..59eccc425323 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -20,6 +20,7 @@ from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -65,6 +66,86 @@ def _create_proposer( device=current_platform.device_type) +def test_prepare_next_token_ids(): + """ + Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded. + Each will produce a device tensor of next_token_ids, taking as input + either the GPU tensor of sampled_token_ids with -1 for rejected tokens, + or the CPU python list[list[int]] with the rejected tokens removed. 
+ """ + device = torch.device(current_platform.device_type) + + num_requests = 4 + num_speculative_tokens = 4 + batch_spec = BatchSpec( + seq_lens=[num_speculative_tokens + 1] * num_requests, + query_lens=[num_speculative_tokens + 1] * num_requests, + ) + + req_ids = [f"req_{i+1}" for i in range(num_requests)] + mock_input_batch = mock.MagicMock(spec=InputBatch) + mock_input_batch.req_ids = req_ids + mock_input_batch.num_reqs = num_requests + mock_input_batch.vocab_size = 100 + + mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids} + mock_requests = {} + for req_id in req_ids: + mock_request = mock.MagicMock(spec=CachedRequestState) + # Each request will have a backup next token id of 10, 20, 30, 40 + mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10 + mock_request.num_computed_tokens = 0 + mock_requests[req_id] = mock_request + + sampled_token_ids = [ + [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled + [0, 1, 2, 3, 4], # all accepted, "4" sampled + [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" + [-1, -1, -1, -1, -1] # this request will be discarded + ] + sampled_token_ids_tensor = torch.tensor(sampled_token_ids, + dtype=torch.int32, + device=device) + sampled_token_ids_cpu = [[i for i in seq if i != -1] + for seq in sampled_token_ids] + + expected_next_token_ids_cpu = [1, 4, 30, 40] + expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu, + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( + sampled_token_ids_cpu, mock_requests, mock_input_batch, + mock_num_scheduled_tokens) + + assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) + num_discarded_reqs = 1 + + expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0], + dtype=torch.int32, + device=device) + + next_token_ids_from_padded, valid_sampled_tokens_count = \ + proposer.prepare_next_token_ids_padded( + common_attn_metadata, sampled_token_ids_tensor, mock_requests, + mock_input_batch, discarded_req_indices, num_discarded_reqs) + + assert torch.equal(next_token_ids_from_padded, + expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, + expected_valid_sampled_tokens_count) + + def test_prepare_inputs(): """ cu_target_query_lens: [0, a, a + b, a + b + c] @@ -457,6 +538,7 @@ def create_deterministic_logits(token_ids): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) @@ -608,6 +690,7 @@ def create_deterministic_logits(token_ids, k: int): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) From fb34a5b96c276bb864dacc8bc1a677daaba8ff93 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Tue, 16 Sep 2025 17:22:03 +0000 Subject: [PATCH 15/15] bugfix Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/eagle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py 
b/vllm/v1/spec_decode/eagle.py index 8a3aa58e72da..2a178ddf4877 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -329,7 +329,9 @@ def propose( exceeds_max_model_len, PADDING_SLOT_ID) # Rebuild attention metadata - attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ + attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder\ .build_for_drafting(common_attn_metadata=common_attn_metadata, draft_index=token_index + 1) for layer_name in self.attn_layer_names:
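With the series applied, the padded drafter path reduces to index bookkeeping on the device: all target positions are kept, and only the positions to sample from shift back by each request's rejection count. As a closing illustration, the sketch below reproduces the arithmetic exercised by `test_prepare_inputs_padded` above; the helper name and signature are illustrative assumptions, not the exact `EagleProposer.prepare_inputs_padded` interface.

```python
import torch

def padded_prepare_indices(
    total_num_tokens: int,                      # query_start_loc[-1], known on the CPU side
    query_start_loc: torch.Tensor,              # [num_reqs + 1], cumulative query lengths
    num_draft_tokens: torch.Tensor,             # [num_reqs], drafts proposed last step
    valid_sampled_tokens_count: torch.Tensor,   # [num_reqs], accepted + bonus tokens
) -> tuple[torch.Tensor, torch.Tensor]:
    # The padded path keeps every target position, so the gather indices for
    # target_token_ids are simply 0..N-1 (no batch compaction, no GPU->CPU copy).
    token_indices = torch.arange(total_num_tokens, dtype=torch.int32,
                                 device=query_start_loc.device)

    # Each request sampled (num_draft_tokens + 1) tokens; anything past the
    # accepted prefix plus the bonus token was rejected.
    num_rejected = num_draft_tokens + 1 - valid_sampled_tokens_count

    # The last token of request i sits at query_start_loc[i + 1] - 1; shifting
    # it back by the rejection count gives the position whose hidden state
    # should seed the next draft token.
    token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
    return token_indices, token_indices_to_sample

# Numbers from test_prepare_inputs_padded above:
qsl = torch.tensor([0, 3, 6, 9], dtype=torch.int32)
drafts = torch.tensor([2, 2, 2], dtype=torch.int32)
valid = torch.tensor([2, 3, 1], dtype=torch.int32)
token_indices, to_sample = padded_prepare_indices(9, qsl, drafts, valid)
assert token_indices.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]
assert to_sample.tolist() == [1, 5, 6]
```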