28 changes: 9 additions & 19 deletions vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py
@@ -38,7 +38,8 @@ def __init__(
         from lmcache.integration.vllm.utils import ENGINE_NAME
         from lmcache.integration.vllm.vllm_adapter import (
             RetrieveStatus, StoreStatus, init_lmcache_engine,
-            lmcache_retrieve_kv, lmcache_should_store, lmcache_store_kv)
+            lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store,
+            lmcache_store_kv)
         logger.info("Initializing LMCacheConfig under kv_transfer_config %s",
                     self.transfer_config)

Expand All @@ -54,6 +55,7 @@ def __init__(
self.cache_config = config.cache_config
self.lmcache_retrieve_kv = lmcache_retrieve_kv
self.lmcache_store_kv = lmcache_store_kv
self.lmcache_should_retrieve = lmcache_should_retrieve
self.lmcache_should_store = lmcache_should_store
self.store_status = StoreStatus
self.retrieve_status = RetrieveStatus
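Note: the two hunks above extend the connector's lazy-import pattern — the lmcache integration functions are imported inside __init__ and bound as instance attributes, so importing this module never requires lmcache itself. A minimal generic sketch of that pattern (math.sqrt stands in for the optional dependency; nothing here is vLLM code):

class LazyConnector:
    def __init__(self) -> None:
        # Deferred import: the dependency is only needed once the
        # connector is actually constructed.
        from math import sqrt
        # Bind once; later methods call the attribute, not the module.
        self.sqrt = sqrt

    def hypotenuse(self, a: float, b: float) -> float:
        return self.sqrt(a * a + b * b)


print(LazyConnector().hypotenuse(3.0, 4.0))  # prints 5.0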
@@ -65,15 +67,11 @@ def recv_kv_caches_and_hidden_states(
     ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
                "ModelInputForGPUWithSamplingMetadata"]:

-        hidden_or_intermediate_states = None
-
-        # TODO (Jiayi): Need to support chunked prefill
-        retrieve_status = self.retrieve_status.PREFILL
-
-        model_input, bypass_model_exec = self.lmcache_retrieve_kv(
-            model_executable, model_input, self.cache_config, kv_caches,
-            retrieve_status)
-
+        retrieve_status = self.lmcache_should_retrieve(model_input)
+        model_input, bypass_model_exec, hidden_or_intermediate_states =\
+            self.lmcache_retrieve_kv(
+                model_executable, model_input, self.cache_config, kv_caches,
+                retrieve_status)
         return hidden_or_intermediate_states, bypass_model_exec, model_input

     def send_kv_caches_and_hidden_states(
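Note: the retrieve path no longer hard-codes RetrieveStatus.PREFILL; lmcache_should_retrieve inspects model_input and picks the status, and lmcache_retrieve_kv now returns the hidden states as well, so a caller can skip the forward pass when bypass_model_exec is True. A hypothetical worker-side sketch of consuming the new three-tuple (the wrapper function and its names are assumptions, not vLLM code):

from typing import Any, List


def execute_with_kv_retrieval(connector: Any, model_executable: Any,
                              model_input: Any,
                              kv_caches: List[Any]) -> Any:
    # Hypothetical wrapper: try to serve the batch from the KV cache
    # before running the model.
    hidden_states, bypass_model_exec, model_input = \
        connector.recv_kv_caches_and_hidden_states(
            model_executable, model_input, kv_caches)
    if not bypass_model_exec:
        # Miss or partial hit: the cache could not fully cover the batch,
        # so run the forward pass as usual.
        hidden_states = model_executable(model_input)
    return hidden_states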
@@ -84,15 +82,7 @@ def send_kv_caches_and_hidden_states(
         hidden_or_intermediate_states: Union[torch.Tensor,
                                              IntermediateTensors],
     ) -> None:
-        num_reqs = 0
-        seq_group_list = model_input.sampling_metadata.seq_groups
-        assert seq_group_list is not None
-        for seq_group in seq_group_list:
-            seq_ids = seq_group.seq_ids
-            for seq_id in seq_ids:
-                num_reqs += 1
-
-        # TODO (Jiayi): Only normal prefill is supported for now

+        store_status = self.lmcache_should_store(model_input)
         self.lmcache_store_kv(
             self.model_config,
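Note: the deleted block only tallied the sequences in the batch by hand; the store decision is now delegated to lmcache_should_store, which presumably derives the StoreStatus from model_input directly. For reference, the removed tally collapses to a one-line sum (SimpleNamespace stands in for vLLM's sampling-metadata objects; illustrative only):

from types import SimpleNamespace

# Two sequence groups holding 2 and 1 sequences: the old loop counted 3.
seq_groups = [SimpleNamespace(seq_ids=[0, 1]), SimpleNamespace(seq_ids=[2])]
num_reqs = sum(len(group.seq_ids) for group in seq_groups)
assert num_reqs == 3
# After this change the connector does no counting of its own;
# lmcache_should_store(model_input) supplies the store decision instead.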