9 changes: 8 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -38,7 +38,8 @@ def allocate(self, request: Request, input_id: int) -> None:
     def get_cached_input_ids(self, request: Request) -> Set[int]:
         return self.cached.get(request.request_id, set())
 
-    def free(self, request: Request, input_id: int) -> None:
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        """Free a single encoder input id for the request."""
         req_id = request.request_id
         if req_id not in self.cached:
             return
@@ -49,6 +50,12 @@ def free(self, request: Request, input_id: int) -> None:
         self.num_free_slots += request.get_num_encoder_tokens(input_id)
         self.freed.append((req_id, input_id))
 
+    def free(self, request: Request) -> None:
+        """Free all cached input ids for the request."""
+        input_ids = self.get_cached_input_ids(request)
+        for input_id in input_ids:
+            self.free_encoder_input(request, input_id)
+
     def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
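The hunks above split the encoder-cache API: the old per-input free(request, input_id) becomes free_encoder_input(request, input_id), and a new free(request) releases every cached encoder input for a request in one call. Below is a minimal standalone sketch of that semantics, not vLLM's actual classes: ToyEncoderCache, its string request ids, and the explicit token counts are invented for illustration.

from typing import Dict, List, Set, Tuple


class ToyEncoderCache:
    """Simplified stand-in for an encoder cache manager, for illustration only."""

    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size
        self.cached: Dict[str, Set[int]] = {}    # req_id -> cached input ids
        self.freed: List[Tuple[str, int]] = []   # (req_id, input_id) pairs

    def allocate(self, req_id: str, input_id: int, num_tokens: int) -> None:
        self.cached.setdefault(req_id, set()).add(input_id)
        self.num_free_slots -= num_tokens

    def free_encoder_input(self, req_id: str, input_id: int,
                           num_tokens: int) -> None:
        """Free a single encoder input id for the request."""
        if input_id not in self.cached.get(req_id, set()):
            return
        self.cached[req_id].discard(input_id)
        if not self.cached[req_id]:
            del self.cached[req_id]
        self.num_free_slots += num_tokens
        self.freed.append((req_id, input_id))

    def free(self, req_id: str, num_tokens_per_input: Dict[int, int]) -> None:
        """Free all cached input ids for the request."""
        # Iterate over a copy, since free_encoder_input mutates the set.
        for input_id in list(self.cached.get(req_id, set())):
            self.free_encoder_input(req_id, input_id,
                                    num_tokens_per_input[input_id])


cache = ToyEncoderCache(cache_size=100)
cache.allocate("req-0", 0, 40)
cache.allocate("req-0", 1, 40)
cache.free_encoder_input("req-0", 0, 40)   # per-input free during decoding
cache.free("req-0", {0: 40, 1: 40})        # whole-request free when finished
assert cache.num_free_slots == 100
assert cache.freed == [("req-0", 0), ("req-0", 1)]

The per-input path matches what the scheduler does once an encoder input's tokens are already in the decoder's KV cache; the whole-request path is what a finished or aborted request needs so no encoder slots leak.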
14 changes: 8 additions & 6 deletions vllm/v1/core/scheduler.py
@@ -202,7 +202,7 @@ def schedule(self) -> "SchedulerOutput":
                 # which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
                 if num_new_tokens == 0:
-                    # The happens when prompt length is divisible by the block
+                    # This happens when prompt length is divisible by the block
                     # size and all blocks are cached. Now we force to recompute
                     # the last block. Note that we have to re-compute an entire
                     # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@ def schedule(self) -> "SchedulerOutput":
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
@@ -433,7 +434,8 @@ def update_from_output(
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)
 
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@
                 # TODO: Update the KV cache manager for prefix caching.
 
                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)
 
                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
@@ -472,21 +476,18 @@ def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
             return True
 
         sampling_params = request.sampling_params
         last_token_id = request.output_token_ids[-1]
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True
 
         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
         return False
 
@@ -525,6 +526,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
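Taken together, the scheduler hunks make _check_stop side-effect-free with respect to resource management: it only sets the finish status and reports whether the request stopped, while update_from_output frees the stopped request exactly once, and _free_request now returns the request's encoder-cache entries along with its KV-cache blocks. Below is a minimal standalone sketch of that control flow; ToyRequest and ToyScheduler are invented stand-ins, not vLLM's Request or Scheduler.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyRequest:
    request_id: str
    max_tokens: int
    eos_token_id: int
    output_token_ids: List[int] = field(default_factory=list)
    status: Optional[str] = None


class ToyScheduler:

    def __init__(self) -> None:
        self.freed_kv: List[str] = []
        self.freed_encoder: List[str] = []

    def _check_stop(self, request: ToyRequest) -> bool:
        # Pure check: classify the stop reason, but free nothing here.
        if len(request.output_token_ids) >= request.max_tokens:
            request.status = "FINISHED_LENGTH_CAPPED"
            return True
        if request.output_token_ids[-1] == request.eos_token_id:
            request.status = "FINISHED_STOPPED"
            return True
        return False

    def _free_request(self, request: ToyRequest) -> None:
        # Single free path: release KV cache and encoder cache together.
        self.freed_kv.append(request.request_id)
        self.freed_encoder.append(request.request_id)

    def update_from_output(self, request: ToyRequest, token_id: int) -> None:
        request.output_token_ids.append(token_id)
        stopped = self._check_stop(request)
        if stopped:
            self._free_request(request)


scheduler = ToyScheduler()
req = ToyRequest(request_id="req-0", max_tokens=8, eos_token_id=2)
scheduler.update_from_output(req, token_id=7)
scheduler.update_from_output(req, token_id=2)   # EOS -> stop, then free once
assert req.status == "FINISHED_STOPPED"
assert scheduler.freed_kv == ["req-0"] and scheduler.freed_encoder == ["req-0"]

The design choice the diff points at is that every finish path funnels through _free_request, so a per-request resource such as the encoder cache only needs to be released in one place.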