vllm/distributed/kv_transfer/kv_connector/simple_connector.py (15 additions, 2 deletions)
@@ -214,6 +214,7 @@ def recv_kv_caches_and_hidden_states(
 
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
 
         hidden_or_intermediate_states_for_one_req = []
@@ -225,9 +226,21 @@ def recv_kv_caches_and_hidden_states(
         # enumerate different requests
         # FIXME(Kuntai): This impl assumes that all requests are prefill.
         for idx, slen in enumerate(seq_lens):
 
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
 
+            if start_pos >= num_prefill_tokens:
+                # This can happen during inflight batching. See:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You should set --enable_chunked_prefill=False "
+                               "and --max_num_batched_tokens "
+                               "should be equal to max_seq_len_to_capture")
+                bypass_model_exec = False
+                assert start_pos == num_prefill_tokens
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
             num_tokens = slen
 
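The early-exit branch added above leans on the layout spelled out in its comment: model_input.input_tokens packs all prefill tokens first and all decode tokens after them, so once the running offset reaches num_prefill_tokens, every remaining seq_lens entry belongs to a decode request. Below is a minimal, self-contained sketch of that offset walk; the toy numbers and plain Python lists are stand-ins for the real tensors, not the connector's actual code.

# Toy illustration of the prefill/decode token layout assumed by the new check.
seq_lens = [4, 3, 1, 1]        # two prefill requests followed by two decode requests
num_prefill_tokens = 7         # 4 + 3; each decode request contributes one token
input_tokens = list(range(9))  # stand-in for model_input.input_tokens

for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    end_pos = start_pos + slen
    if start_pos >= num_prefill_tokens:
        # Everything from here on is decode tokens; the connector only serves
        # prefill, so it stops and lets normal forwarding handle the rest.
        print(f"request {idx}: decode entry at offset {start_pos}, stopping")
        break
    print(f"request {idx}: prefill tokens {input_tokens[start_pos:end_pos]}")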
@@ -288,7 +301,7 @@ def recv_kv_caches_and_hidden_states(
             # Here we will fall back to normal model forwarding
             # But optionally you can adjust model_input so that you only do
             # prefilling on those tokens that are missing KV caches.
-            logger.debug(
+            logger.warning(
                 "[rank%d]: Failed to receive all KVs and hidden "
                 "states, redo model forwarding.", torch.distributed.get_rank())
             hidden_or_intermediate_states = None
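When this fallback fires (including the early exit on decode entries above), bypass_model_exec is left False and the hidden states come back as None, which tells the caller to redo the forward pass itself. A simplified, hypothetical caller-side sketch of that contract follows; it assumes the method returns a three-tuple of hidden states, bypass flag, and model input, and the names run_step, model, model_input, and kv_caches are placeholders rather than the real worker code.

def run_step(connector, model, model_input, kv_caches):
    # Ask the connector for remotely produced KV caches and hidden states.
    hidden_states, bypass_model_exec, model_input = (
        connector.recv_kv_caches_and_hidden_states(model, model_input,
                                                   kv_caches))
    if not bypass_model_exec:
        # The warning above was logged: not every request could be served
        # from the remote prefill instance, so recompute locally. The plain
        # call here is a placeholder for the runner's full forward pass.
        hidden_states = model(model_input)
    return hidden_states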