|
12 | 12 | from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
|
13 | 13 | NaiveBlockAllocator)
|
14 | 14 | from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
| 15 | +from vllm.logger import init_logger |
15 | 16 | from vllm.sequence import Sequence
|
16 | 17 |
|
17 | 18 | PrefixHash = int
|
|
21 | 22 | # then we know this block hasn't been accessed yet.
|
22 | 23 | _DEFAULT_LAST_ACCESSED_TIME = -1
|
23 | 24 |
|
| 25 | +logger = init_logger(__name__) |
| 26 | + |
24 | 27 |
|
25 | 28 | class BlockTracker:
|
26 | 29 | """Used to track the status of a block inside the prefix caching allocator
|
@@ -105,7 +108,8 @@ def __init__(
|
105 | 108 |
|
106 | 109 | # Evictor used to maintain how we want to handle those computed blocks
|
107 | 110 | # if we find memory pressure is high.
|
108 |
| - self.evictor: Evictor = make_evictor(eviction_policy) |
| 111 | + self.eviction_policy = eviction_policy |
| 112 | + self.evictor: Evictor = make_evictor(self.eviction_policy) |
109 | 113 |
|
110 | 114 | # We share the refcounter between allocators. This allows us to promote
|
111 | 115 | # blocks originally allocated in the hashless allocator to immutable
|
@@ -428,6 +432,44 @@ def all_block_ids(self) -> FrozenSet[int]:
|
428 | 432 | def get_prefix_cache_hit_rate(self) -> float:
|
429 | 433 | return self.metric_data.get_hit_rate()
|
430 | 434 |
|
| 435 | + def reset_prefix_cache(self) -> bool: |
| 436 | + """Reset prefix cache. This function may be used in RLHF |
| 437 | + flows to invalidate prefix caching after the weights are updated, |
| 438 | + or used for resetting prefix caching status for benchmarking. |
| 439 | +
|
| 440 | + Returns: |
| 441 | + bool: True if the prefix cache is successfully reset, |
| 442 | + False otherwise. |
| 443 | + """ |
| 444 | + num_used_blocks = (self.get_num_total_blocks() - |
| 445 | + self.get_num_free_blocks()) |
| 446 | + if num_used_blocks > 0: |
| 447 | + logger.warning( |
| 448 | + "Failed to reset prefix cache because some " |
| 449 | + "blocks (%d) are not freed yet", num_used_blocks) |
| 450 | + return False |
| 451 | + |
| 452 | + # Free all blocks in the evictor. |
| 453 | + while (block_id := |
| 454 | + self._maybe_allocate_evicted_block_id()) is not None: |
| 455 | + self._hashless_allocator.free_block_id(block_id) |
| 456 | + |
| 457 | + # Should not have any cached blocks because all blocks are evicted. |
| 458 | + assert not self._cached_blocks |
| 459 | + |
| 460 | + # Reset the evictor. |
| 461 | + self.evictor = make_evictor(self.eviction_policy) |
| 462 | + |
| 463 | + # Reset the block tracker. |
| 464 | + for block_id in self._block_tracker: |
| 465 | + self._block_tracker[block_id] = BlockTracker() |
| 466 | + |
| 467 | + # Reset the metrics. |
| 468 | + self.metric_data = CacheMetricData() |
| 469 | + |
| 470 | + logger.info("Successfully reset prefix cache") |
| 471 | + return True |
| 472 | + |
431 | 473 | def is_block_cached(self, block: Block) -> bool:
|
432 | 474 | assert block.content_hash is not None
|
433 | 475 | return block.content_hash in self._cached_blocks
|
|
0 commit comments