diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e9762ac0..8e2fbf36b4de 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states(
 
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
 
         hidden_or_intermediate_states_for_one_req = []
@@ -225,9 +226,21 @@ def recv_kv_caches_and_hidden_states(
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
-
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and --max_num_batched_tokens "
+                               "should be equal to max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
             num_tokens = slen
 
@@ -288,7 +301,7 @@ def recv_kv_caches_and_hidden_states(
             # Here we will fall back to normal model forwarding
             # But optionally you can adjust model_input so that you only do
             # prefilling on those tokens that are missing KV caches.
-            logger.debug(
+            logger.warning(
                 "[rank%d]: Failed to receive all KVs and hidden "
                 "states, redo model forwarding.", torch.distributed.get_rank())
             hidden_or_intermediate_states = None
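
For context on the early exit added above, here is a standalone sketch (not part of the patch): with inflight batching the batched input_tokens is laid out as [prefill tokens | decode tokens], so once the cumulative offset start_pos reaches num_prefill_tokens the remaining seq_lens entries belong to decode requests and the loop stops injecting received KV caches. The numbers below are invented for illustration.

# Illustrative, self-contained sketch; the batch contents are made up.
# It mirrors the layout described in the added comment:
#   input_tokens[:num_prefill_tokens]  -> prefill tokens
#   input_tokens[num_prefill_tokens:]  -> decode tokens
seq_lens = [4, 3, 9]      # two prefill requests plus one decode request
num_prefill_tokens = 7    # 4 + 3 prefill tokens sit at the front of the batch

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    end_pos = start_pos + slen

    if start_pos >= num_prefill_tokens:
        # Past the prefill region: the remaining entries are decode requests,
        # whose tokens cannot be sliced out of input_tokens this way,
        # so stop matching received KV caches here.
        print(f"request {idx}: starts at {start_pos}, past prefill region -> stop")
        break

    print(f"request {idx}: prefill tokens at [{start_pos}:{end_pos}]")

Because the prefill requests come first and their token counts sum to num_prefill_tokens, the first entry that trips the check starts exactly at num_prefill_tokens, which is what the assert in the patch relies on.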