6 changes: 4 additions & 2 deletions vllm/core/block/cpu_gpu_block_allocator.py
@@ -341,8 +341,10 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         assert device in self._allocators
         return self._allocators[device].get_prefix_cache_hit_rate()
 
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
+        if device:
+            return self._allocators[device].reset_prefix_cache()
         success = True
         for allocator in self._allocators.values():
             success = success and allocator.reset_prefix_cache()
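
For orientation, a minimal usage sketch of the allocator-level semantics above (`allocator` is a hypothetical CpuGpuBlockAllocator instance; this sketch is not part of the diff):

    from vllm.utils import Device

    # `allocator`: a hypothetical CpuGpuBlockAllocator as changed above
    allocator.reset_prefix_cache(Device.GPU)  # reset only the GPU allocator's cache
    allocator.reset_prefix_cache(Device.CPU)  # reset only the CPU allocator's cache
    allocator.reset_prefix_cache()            # device=None: reset all devices, as before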
2 changes: 1 addition & 1 deletion vllm/core/block/interfaces.py
@@ -305,7 +305,7 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache."""
         pass
 
4 changes: 2 additions & 2 deletions vllm/core/block_manager.py
@@ -456,8 +456,8 @@ def get_num_free_cpu_blocks(self) -> int:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_allocator.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_allocator.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_allocator.reset_prefix_cache(device)
 
     def _can_swap(self,
                   seq_group: SequenceGroup,
6 changes: 3 additions & 3 deletions vllm/core/interfaces.py
@@ -2,7 +2,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -125,8 +125,8 @@ def get_prefix_cache_hit_rate(self, device: Device) -> float:
         pass
 
     @abstractmethod
-    def reset_prefix_cache(self) -> bool:
-        """Reset prefix cache for all devices."""
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        """Reset prefix cache for specified or all devices."""
         pass
 
     @abstractmethod
4 changes: 2 additions & 2 deletions vllm/core/placeholder_block_space_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple
from typing import List, Optional, Tuple

from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
Expand Down Expand Up @@ -92,7 +92,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup,
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1

def reset_prefix_cache(self) -> bool:
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return True

def get_num_cached_tokens(self, seq: Sequence) -> int:
Expand Down
4 changes: 2 additions & 2 deletions vllm/core/scheduler.py
@@ -634,8 +634,8 @@ def has_unfinished_seqs(self) -> bool:
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         return self.block_manager.get_prefix_cache_hit_rate(device)
 
-    def reset_prefix_cache(self) -> bool:
-        return self.block_manager.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.block_manager.reset_prefix_cache(device)
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
7 changes: 4 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -35,7 +35,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import Device, deprecate_kwargs, weak_bind
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1216,8 +1216,9 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         self.engine.stop_profile()
 
-    async def reset_prefix_cache(self) -> None:
-        self.engine.reset_prefix_cache()
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        self.engine.reset_prefix_cache(device)
 
     async def sleep(self, level: int = 1) -> None:
         self.engine.sleep(level)
4 changes: 2 additions & 2 deletions vllm/engine/llm_engine.py
@@ -955,12 +955,12 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
-    def reset_prefix_cache(self) -> bool:
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
 
         success = True
         for scheduler in self.scheduler:
-            success = success and scheduler.reset_prefix_cache()
+            success = success and scheduler.reset_prefix_cache(device)
         return success
 
     @staticmethod
7 changes: 4 additions & 3 deletions vllm/engine/multiprocessing/__init__.py
@@ -13,7 +13,7 @@
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
 
@@ -123,8 +123,9 @@ class RPCUProfileRequest(Enum):
     STOP_PROFILE = 2
 
 
-class RPCResetPrefixCacheRequest(Enum):
-    RESET_PREFIX_CACHE = 1
+@dataclass
+class RPCResetPrefixCacheRequest:
+    device: Device
 
 
 class RPCSleepRequest(Enum):
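
Switching RPCResetPrefixCacheRequest from an Enum to a dataclass lets the RPC message carry the target device as a payload. A hedged construction sketch, mirroring the client.py hunk below (note the field is annotated Device, but the client passes None when no device is given):

    from vllm.engine.multiprocessing import RPCResetPrefixCacheRequest
    from vllm.utils import Device

    # New-style message carrying a payload; the old Enum member had no room for one.
    request = RPCResetPrefixCacheRequest(device=Device.GPU)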
7 changes: 4 additions & 3 deletions vllm/engine/multiprocessing/client.py
@@ -47,7 +47,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
-from vllm.utils import deprecate_kwargs
+from vllm.utils import Device, deprecate_kwargs
 
 logger = init_logger(__name__)
 
@@ -684,11 +684,12 @@ async def stop_profile(self) -> None:
         await self._send_one_way_rpc_request(
             request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
 
         await self._send_one_way_rpc_request(
-            request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
+            request=RPCResetPrefixCacheRequest(device),
             socket=self.input_socket)
 
     async def sleep(self, level: int = 1) -> None:
5 changes: 3 additions & 2 deletions vllm/engine/protocol.py
@@ -18,7 +18,7 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import collect_from_async_generator, random_uuid
+from vllm.utils import Device, collect_from_async_generator, random_uuid
 
 logger = init_logger(__name__)
 
@@ -274,7 +274,8 @@ async def stop_profile(self) -> None:
         ...
 
     @abstractmethod
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
         """Reset the prefix cache"""
         ...
 
7 changes: 4 additions & 3 deletions vllm/entrypoints/llm.py
@@ -42,7 +42,8 @@
                                               get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
+from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
+                        is_list_of)
 
 logger = init_logger(__name__)
 
@@ -1187,8 +1188,8 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.llm_engine.stop_profile()
 
-    def reset_prefix_cache(self) -> bool:
-        return self.llm_engine.reset_prefix_cache()
+    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
+        return self.llm_engine.reset_prefix_cache(device)
 
     def sleep(self, level: int = 1):
         """
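
At the public API level, a hedged usage sketch of the updated entrypoint (the model name and prefix-caching flag are illustrative, not part of this diff):

    from vllm import LLM
    from vllm.utils import Device

    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
    llm.generate("The quick brown fox")  # populate the prefix cache
    llm.reset_prefix_cache(Device.GPU)   # new: reset only the GPU-side cache
    llm.reset_prefix_cache()             # unchanged behavior: reset all devices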
10 changes: 7 additions & 3 deletions vllm/entrypoints/openai/api_server.py
@@ -85,7 +85,7 @@
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
                         is_valid_ipv6_address, set_ulimit)
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -677,8 +677,12 @@ async def reset_prefix_cache(raw_request: Request):
     Reset the prefix cache. Note that we currently do not check if the
     prefix cache is successfully reset in the API server.
     """
-    logger.info("Resetting prefix cache...")
-    await engine_client(raw_request).reset_prefix_cache()
+    device = None
+    device_str = raw_request.query_params.get("device")
+    if device_str is not None:
+        device = Device[device_str.upper()]
+    logger.info("Resetting prefix cache with specific %s...", str(device))
+    await engine_client(raw_request).reset_prefix_cache(device)
     return Response(status_code=200)
 
 @router.post("/sleep")

Review comment (Member), on the logger.info line, suggesting clearer wording:
    logger.info("Resetting prefix cache with specific device: %s...", str(device))
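
The endpoint thus accepts an optional device query parameter, parsed case-insensitively into the Device enum. A hedged client-side sketch (server URL illustrative):

    import requests

    # Reset only the GPU prefix cache; omit the parameter to reset all devices.
    requests.post("http://localhost:8000/reset_prefix_cache", params={"device": "gpu"})
    requests.post("http://localhost:8000/reset_prefix_cache")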
7 changes: 5 additions & 2 deletions vllm/v1/engine/async_llm.py
@@ -24,7 +24,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import cdiv, kill_process_tree
+from vllm.utils import Device, cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -398,7 +398,10 @@ async def start_profile(self) -> None:
     async def stop_profile(self) -> None:
         await self.engine_core.profile_async(False)
 
-    async def reset_prefix_cache(self) -> None:
+    async def reset_prefix_cache(self,
+                                 device: Optional[Device] = None) -> None:
+        if device == Device.CPU:
+            raise ValueError("Not supported on CPU.")
         await self.engine_core.reset_prefix_cache_async()
 
     async def sleep(self, level: int = 1) -> None:
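
In v1, the device argument is only validated, not forwarded: reset_prefix_cache_async() is still called without it. A hedged behavior sketch (inside an async context, with engine a hypothetical AsyncLLM instance):

    from vllm.utils import Device

    # `engine`: a hypothetical v1 AsyncLLM instance
    await engine.reset_prefix_cache()            # resets the prefix cache
    await engine.reset_prefix_cache(Device.GPU)  # same effect; device not forwarded
    await engine.reset_prefix_cache(Device.CPU)  # raises ValueError("Not supported on CPU.")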
3 changes: 2 additions & 1 deletion vllm/v1/engine/llm_engine.py
@@ -20,6 +20,7 @@
 from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Device
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -226,7 +227,7 @@ def start_profile(self):
     def stop_profile(self):
         self.engine_core.profile(False)
 
-    def reset_prefix_cache(self):
+    def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()
 
     def sleep(self, level: int = 1):

Review comment (Member), on reset_prefix_cache:
    It looks like this isn't implemented for V1... should that be included? (we are aiming to deprecate v0 now)

Reply (Contributor, author):
    I will implement this for v1 after #13377 is merged.