diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index ccab04628a16..e7f6b68fc3f7 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch model_dir = "meta-llama/Llama-3.1-8B-Instruct" eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" @@ -64,6 +66,86 @@ def _create_proposer( device=current_platform.device_type) +def test_prepare_next_token_ids(): + """ + Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded. + Each will produce a device tensor of next_token_ids, taking as input + either the GPU tensor of sampled_token_ids with -1 for rejected tokens, + or the CPU python list[list[int]] with the rejected tokens removed. + """ + device = torch.device(current_platform.device_type) + + num_requests = 4 + num_speculative_tokens = 4 + batch_spec = BatchSpec( + seq_lens=[num_speculative_tokens + 1] * num_requests, + query_lens=[num_speculative_tokens + 1] * num_requests, + ) + + req_ids = [f"req_{i+1}" for i in range(num_requests)] + mock_input_batch = mock.MagicMock(spec=InputBatch) + mock_input_batch.req_ids = req_ids + mock_input_batch.num_reqs = num_requests + mock_input_batch.vocab_size = 100 + + mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids} + mock_requests = {} + for req_id in req_ids: + mock_request = mock.MagicMock(spec=CachedRequestState) + # Each request will have a backup next token id of 10, 20, 30, 40 + mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10 + mock_request.num_computed_tokens = 0 + mock_requests[req_id] = mock_request + + sampled_token_ids = [ + [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled + [0, 1, 2, 3, 4], # all accepted, "4" sampled + [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" + [-1, -1, -1, -1, -1] # this request will be discarded + ] + sampled_token_ids_tensor = torch.tensor(sampled_token_ids, + dtype=torch.int32, + device=device) + sampled_token_ids_cpu = [[i for i in seq if i != -1] + for seq in sampled_token_ids] + + expected_next_token_ids_cpu = [1, 4, 30, 40] + expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu, + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( + sampled_token_ids_cpu, mock_requests, mock_input_batch, + mock_num_scheduled_tokens) + + assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) + num_discarded_reqs = 1 + + expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0], + dtype=torch.int32, + device=device) + + next_token_ids_from_padded, valid_sampled_tokens_count = \ + proposer.prepare_next_token_ids_padded( + common_attn_metadata, sampled_token_ids_tensor, mock_requests, + mock_input_batch, discarded_req_indices, num_discarded_reqs) + + assert torch.equal(next_token_ids_from_padded, + expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, + expected_valid_sampled_tokens_count) + + def test_prepare_inputs(): """ cu_target_query_lens: [0, a, a + b, a + b + c] @@ -90,10 +172,24 @@ def test_prepare_inputs(): device=device, ) - # Rejected tokens per request: [1, 3, 2] - num_rejected_tokens = torch.tensor([1, 3, 2], - dtype=torch.int32, - device=device) + # If there are `k` sampled tokens, then `k-1` tokens are draft tokens + # from the previous iteration, and the last token is the bonus token sampled + # from the base model. + num_draft_tokens = [3, 6, 4] # one less than query_lens + # num rejected tokens is [1, 3, 2] + ACCEPT_TOKEN = 0 + BONUS_TOKEN = 1 + REJECT_TOKEN = -1 + sampled_token_ids = [ + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + [ + ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, + REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN + ], + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN] + ] + sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN] + for seq in sampled_token_ids] # Expected calculations: # query_len_per_req = [4, 7, 5] @@ -125,7 +221,7 @@ def test_prepare_inputs(): proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, num_rejected_tokens.cpu()) + common_attn_metadata, sampled_token_ids, num_draft_tokens) assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) @@ -133,6 +229,77 @@ def test_prepare_inputs(): assert torch.equal(token_indices, expected_token_indices) +def test_prepare_inputs_padded(): + """ + Input scenario is 3 requests with num_speculative_tokens == 2 and: + - Request 1: query_len = 3, rejected = 1 + - Request 2: query_len = 3, rejected = 0 + - Request 3: query_len = 3, rejected = 2 + + Expected outputs: + token_indices: [0, 1, 2, + 3, 4, 5, + 6, 7, 8] + Reason: Deferred computation should not disturb the original indices. + + token_indices_to_sample: [1, 5, 6] + Reason: After accounting for rejections, these are the valid token positions + from the original indices to sample from. + """ + + device = torch.device(current_platform.device_type) + + expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.int32, + device=device) + expected_token_indices_to_sample = torch.tensor([1, 5, 6], + dtype=torch.int32, + device=device) + + num_speculative_tokens = 2 + batch_spec = BatchSpec( + seq_lens=[3, 3, 3], + query_lens=[3, 3, 3], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] + expected_query_start_loc = torch.tensor([0, 3, 6, 9], + dtype=torch.int32, + device=device) + spec_decode_metadata = SpecDecodeMetadata.make_dummy( + draft_token_ids=[[0] * num_speculative_tokens] * 3, + device=device, + ) + + # num_rejected_tokens = [1, 0, 2] + # num_draft_tokens = [2, 2, 2] + # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens + valid_sampled_tokens_count = torch.tensor([2, 3, 1], + dtype=torch.int32, + device=device) + + proposer = _create_proposer("eagle", num_speculative_tokens) + + output_metadata, token_indices, token_indices_to_sample = \ + proposer.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + + assert output_metadata.max_query_len == 3 + assert torch.equal(output_metadata.query_start_loc, + expected_query_start_loc) + assert torch.equal(token_indices, expected_token_indices) + assert torch.equal(token_indices_to_sample, + expected_token_indices_to_sample) + + @pytest.mark.parametrize("method", ["eagle", "eagle3"]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @@ -373,6 +540,7 @@ def create_deterministic_logits(token_ids): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) @@ -526,6 +694,7 @@ def create_deterministic_logits(token_ids, k: int): target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index b2d50e385233..fca8c28e5c61 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -83,6 +83,11 @@ class SpeculativeConfig: disable_by_batch_size: Optional[int] = None """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" + disable_padded_drafter_batch: bool = False + """Disable input padding for speculative decoding. If set to True, + speculative input batches can contain sequences of different lengths, + which may only be supported by certain attention backends. This currently + only affects the EAGLE method of speculation.""" # Ngram proposer configuration prompt_lookup_max: Optional[int] = None diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5154b29405b6..2a178ddf4877 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -27,6 +27,9 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer +from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.ubatching import dbo_current_ubatch_id logger = init_logger(__name__) @@ -94,20 +97,26 @@ def __init__( dtype=self.dtype, device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. max_batch_size = vllm_config.scheduler_config.max_num_seqs - self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_batch_size + 1, - device=device, - dtype=torch.int32, - ) + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) + self.backup_next_token_ids = CpuGpuBuffer( + max_batch_size, + dtype=torch.int32, + pin_memory=is_pin_memory_available(), + device=device, + with_numpy=True) + # Determine allowed attention backends once during initialization. self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] if current_platform.is_rocm(): @@ -156,13 +165,16 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -228,6 +240,12 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) + + # Early exit if there is only one draft token to be generated. + if self.num_speculative_tokens == 1: + draft_token_ids = logits.argmax(dim=-1) + return draft_token_ids.view(-1, 1) + positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] @@ -245,15 +263,12 @@ def propose( draft_token_ids = logits.argmax(dim=-1) - # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # TODO: Currently, MTP module released by deepseek only has - # one layer. Adapt this code to support multiple layers once - # there's a multi-layer MTP module. - assert isinstance(attn_metadata, self.allowed_attn_types) + if not isinstance(attn_metadata, self.allowed_attn_types): + raise ValueError( + f"Unsupported attention metadata type for speculative " + "decoding with num_speculative_tokens > 1: " + f"{type(attn_metadata)}. Supported types are: " + f"{self.allowed_attn_types}") # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -263,10 +278,13 @@ def propose( input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + + common_attn_metadata.num_actual_tokens = batch_size + common_attn_metadata.max_query_len = 1 + common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + self.token_arange_np[:batch_size + 1]).clone() + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -286,27 +304,38 @@ def propose( positions) # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + common_attn_metadata.seq_lens += 1 + common_attn_metadata.seq_lens_cpu += 1 # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, + 1) + + common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( + block_ids = common_attn_metadata.block_table_tensor.gather( dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID) + + # Rebuild attention metadata + attn_metadata_builder = \ + self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder\ + .build_for_drafting(common_attn_metadata=common_attn_metadata, + draft_index=token_index + 1) + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata # copy inputs to buffer for cudagraph self.input_ids[:batch_size] = input_ids @@ -347,6 +376,158 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids + def prepare_next_token_ids_cpu( + self, sampled_token_ids: list[list[int]], + requests: dict[str, + CachedRequestState], gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = (req_state.num_computed_tokens + + num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_padded(self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int) -> \ + tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. + discard_sampled_tokens_req_indices = \ + discard_request_indices[:num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + valid_mask = ( + (valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids.gpu[:batch_size]) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_inputs_padded(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. + """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ + - num_rejected_tokens_gpu + + return spec_common_attn_metadata, token_indices, token_indices_to_sample + def propose_tree( self, batch_size: int, @@ -520,11 +701,11 @@ def propose_tree( def prepare_inputs( self, common_attn_metadata: CommonAttentionMetadata, - # [batch_size] - num_rejected_tokens: torch.Tensor + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ - This function is used to prepare the inputs for the spec decode. + This function is used to prepare the inputs for speculative decoding. It updates to the common_attn_metadata to account for the rejected tokens (and newly sampled tokens). It also returns the token indices of the tokens that should be fed to the speculator. @@ -545,6 +726,13 @@ def prepare_inputs( # q1, q1 + 1, ..., q1 + q2 - n2 - 1, # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 339b9937b73f..6717622efb80 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -64,7 +64,10 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - return self.output_token_ids[idx - self.num_prompt_tokens] + elif idx - self.num_prompt_tokens < len(self.output_token_ids): + return self.output_token_ids[idx - self.num_prompt_tokens] + else: + return -1 class InputBatch: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ae748dee43c..fc90974a3ad5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -344,6 +344,10 @@ def __init__( self.hidden_size, dtype=self.dtype, numpy=False) + self.discard_request_indices = self._make_buffer(self.max_num_reqs, + dtype=torch.int64) + self.num_discarded_requests = 0 + self.num_draft_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, @@ -977,6 +981,21 @@ def _prepare_inputs( seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() + num_tokens = [ + self.requests[r].num_tokens for r in self.input_batch.req_ids + ] + num_tokens_np = np.array(num_tokens, dtype=np.int32) + + # Record the index of requests that should not be sampled, + # so that we could clear the sampled tokens before returning + discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np + discard_request_indices = np.nonzero(discard_requests_mask)[0] + self.num_discarded_requests = len(discard_request_indices) + self.discard_request_indices.np[:self.num_discarded_requests] = ( + discard_request_indices) + + self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) + # Copy the tensors to the GPU. self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens) @@ -1975,23 +1994,12 @@ def _bookkeeping_sync( if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. - # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] + for i in discard_sampled_tokens_req_indices: + gen = self.input_batch.generators.get(int(i)) + if gen is not None: + gen.set_offset(gen.get_offset() - 4) # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. @@ -2028,10 +2036,10 @@ def _bookkeeping_sync( ) # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + valid_sampled_token_ids[int(i)].clear() else: valid_sampled_token_ids = [] - invalid_req_indices = list(discard_sampled_tokens_req_indices) + invalid_req_indices = discard_sampled_tokens_req_indices.tolist() invalid_req_indices_set = set(invalid_req_indices) assert sampled_token_ids.shape[-1] == 1 @@ -2231,6 +2239,28 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) + def propose_draft_token_ids(sampled_token_ids): + assert spec_decode_common_attn_metadata is not None + with record_function_or_nullcontext("Draft"): + self._draft_token_ids = self.propose_draft_token_ids( + scheduler_output, + sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + + use_padded_batch_for_eagle = self.speculative_config and \ + self.speculative_config.use_eagle() and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch_for_eagle: + # EAGLE speculative decoding can use the GPU sampled tokens + # as inputs, and does not need to wait for bookkeeping to finish. + propose_draft_token_ids(sampler_output.sampled_token_ids) + with record_function_or_nullcontext("Bookkeep"): ( num_nans_in_logits, @@ -2244,19 +2274,10 @@ def execute_model( logits, hidden_states, num_scheduled_tokens) - if self.speculative_config: - assert spec_decode_common_attn_metadata is not None - with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - self.input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) + if self.speculative_config and not use_padded_batch_for_eagle: + # ngram and other speculative decoding methods use the sampled + # tokens on the CPU, so they are run after bookkeeping. + propose_draft_token_ids(valid_sampled_token_ids) with record_function_or_nullcontext("EPLB"): self.eplb_step() @@ -2296,7 +2317,7 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]: def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], + sampled_token_ids: Union[torch.Tensor, list[list[int]]], sampling_metadata: SamplingMetadata, hidden_states: torch.Tensor, sample_hidden_states: torch.Tensor, @@ -2306,11 +2327,14 @@ def propose_draft_token_ids( ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": + assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) + if sample_hidden_states.shape[0] == len(sampled_token_ids): # The input to the target model does not include draft tokens. hidden_states = sample_hidden_states @@ -2331,27 +2355,37 @@ def propose_draft_token_ids( ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - req_ids = self.input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) + + if self.speculative_config.disable_padded_drafter_batch: + # When padded-batch is disabled, the sampled_token_ids should be + # the cpu-side list[list[int]] of valid sampled tokens for each + # request, with invalid requests having empty lists. + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." + next_token_ids = self.drafter.prepare_next_token_ids_cpu( + sampled_token_ids, self.requests, self.input_batch, + scheduler_output.num_scheduled_tokens) + else: + # When using padded-batch, the sampled_token_ids should be + # the gpu tensor of sampled tokens for each request, of shape + # (num_reqs, num_spec_tokens + 1) with rejected tokens having + # value -1. + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ + self.drafter.prepare_next_token_ids_padded( + common_attn_metadata, + sampled_token_ids, + self.requests, + self.input_batch, + self.discard_request_indices.gpu, + self.num_discarded_requests + ) if spec_decode_metadata is None: + token_indices_to_sample = None # input_ids can be None for multimodal models. target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] # TODO(woosuk): Support M-RoPE. @@ -2363,17 +2397,20 @@ def propose_draft_token_ids( else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: - # TODO(woosuk): Refactor this. - num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) + if self.speculative_config.disable_padded_drafter_batch: + token_indices_to_sample = None + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens) + else: + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. @@ -2393,6 +2430,7 @@ def propose_draft_token_ids( target_positions=target_positions, target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds,