diff --git a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py index 674a512d5638..9e594563b456 100644 --- a/python/ray/llm/_internal/serve/deployments/llm/llm_server.py +++ b/python/ray/llm/_internal/serve/deployments/llm/llm_server.py @@ -694,7 +694,6 @@ def as_deployment( @serve.deployment( - # TODO make this configurable autoscaling_config={ "min_replicas": 1, "initial_replicas": 1, diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_tree.py similarity index 76% rename from python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py rename to python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_tree.py index 3cba4f5223d8..b7150b34f62a 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_tree.py @@ -1,13 +1,17 @@ from __future__ import annotations +import asyncio import logging import os from threading import RLock from typing import Any, Dict, List, Optional, Tuple import ray +from ray.serve._private.constants import ( + SERVE_LOGGER_NAME, +) -logger = logging.getLogger(__name__) +logger = logging.getLogger(SERVE_LOGGER_NAME) class Node: @@ -86,12 +90,13 @@ def __init__(self) -> None: # Root is always the head of the LRU list for each tenant. self.root: Node = Node() - # Tracks total character count per tenant. Can be used by the replica scheduler to determine which tenant to evict, and by how much. + # Tracks total character count per tenant. Can be used by the replica request router to determine which tenant to evict, and by how much. # Also uses the keys to track the active tenants in the tree. self.tenant_to_char_count: Dict[str, int] = {} # LRU tracking - root is always the head, tail is the least recently used. 
self.tenant_to_lru_tail: Dict[str, Optional[Node]] = {} + self._eviction_task: Optional[asyncio.Task] = None @staticmethod def _shared_prefix_count(a: str, b: str) -> int: @@ -113,6 +118,8 @@ def _get_lru_chain(self, tenant: str) -> List[Node]: Note: This method is intended to be used only in tests. """ with self.lock: + if tenant not in self.tenant_to_char_count: + return [] nodes = [] current_node = self.root while current_node: @@ -120,27 +127,6 @@ def _get_lru_chain(self, tenant: str) -> List[Node]: current_node = current_node.tenant_to_older_node.get(tenant) return nodes - def _add_tenant(self, tenant: str) -> None: - """ - Add a new tenant to the tree. - - If the tenant already exists, this is a no-op with a warning log. - - Args: - tenant: Tenant to add - """ - with self.lock: - if tenant in self.tenant_to_char_count: - logger.warning(f"Tenant '{tenant}' already exists. No action taken.") - return - - self.tenant_to_char_count[tenant] = 0 - self.tenant_to_lru_tail[tenant] = self.root - - # Initialize the root node as the head of the LRU list for this tenant - self.root.tenant_to_newer_node[tenant] = None - self.root.tenant_to_older_node[tenant] = None - def _insert_node_into_linked_list( self, node: Node, @@ -153,7 +139,9 @@ def _insert_node_into_linked_list( """ with self.lock: if tenant not in self.tenant_to_char_count: - logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + logger.debug( + f"[_insert_node_into_linked_list] Tenant '{tenant}' does not exist. No action taken." + ) return # Skip if node is the root @@ -178,7 +166,9 @@ def _remove_node_from_linked_list(self, node: Node, tenant: str) -> None: """ with self.lock: if tenant not in self.tenant_to_char_count: - logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + logger.debug( + f"[_remove_node_from_linked_list] Tenant '{tenant}' does not exist. No action taken." 
+ ) return # Skip if node is the root @@ -216,11 +206,13 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ with self.lock: if tenant not in self.tenant_to_char_count: - logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + logger.debug( + f"[_remove_tenant_single_node] Tenant '{tenant}' does not exist. No action taken." + ) return 0 if tenant not in node.tenant_to_last_access_time: - logger.warning( - f"Tenant '{tenant}' does not have node '{node.text}'. No action taken." + logger.debug( + f"[_remove_tenant_single_node] Tenant '{tenant}' does not have node '{node.text}'. No action taken." ) return 0 @@ -239,11 +231,38 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len + def add_tenants(self, tenants: List[str], time_s: float) -> None: + """ + Add multiple new tenants to the tree. Also inserts an empty string for each tenant into the tree. + + For each tenant that already exists, a debug message is logged and that tenant is skipped. + + Args: + tenants: List of tenants to add + time_s: Current timestamp in seconds + """ + with self.lock: + for tenant in tenants: + if tenant in self.tenant_to_char_count: + logger.debug( + f"[_add_tenants] Tenant '{tenant}' already exists. Skipping." + ) + continue + + self.tenant_to_char_count[tenant] = 0 + self.tenant_to_lru_tail[tenant] = self.root + + # Initialize the root node as the head of the LRU list for this tenant + self.root.tenant_to_newer_node[tenant] = None + self.root.tenant_to_older_node[tenant] = None + self.insert("", tenant, time_s) + def insert(self, text: str, tenant: str, time_s: float) -> None: """ - Insert text into tree for a specific tenant. + Insert text into tree for a specific tenant, but only if the tenant already exists. - If the tenant doesn't already exist in the tree, it will be automatically added. + If the tenant doesn't exist in the tree, this will log a debug message and return without + inserting anything.
Use add_tenants() first to add a new tenant. Args: text: Text to insert @@ -263,7 +282,10 @@ def insert(self, text: str, tenant: str, time_s: float) -> None: """ with self.lock: if tenant not in self.tenant_to_char_count: - self._add_tenant(tenant) + logger.debug( + f"[_insert] Tenant '{tenant}' does not exist. Use add_tenants() first." + ) + return curr_node: Node = self.root i: int = 0 @@ -373,10 +395,6 @@ def prefix_match( If the list of available tenants doesn't match any tenants in the tree: returns ("", None) When no prefix match is found (does not traverse further than the root node): returns ("", list of available tenants) When a prefix match is found: returns (matched_prefix, list of tenants that own the matched node) - - Note: - A tenant is unable to be returned by prefix_match until it has inserted text into the tree, even if _add_tenant is called. - The replica scheduler is responsible for inserting text into new replicas; it should not only rely on prefix_match to select replicas. """ with self.lock: if available_tenants: @@ -433,38 +451,47 @@ def prefix_match( return matched_text, matched_tenants - def remove_tenant(self, tenant: str) -> int: + def remove_tenants(self, tenants: List[str]) -> Dict[str, int]: """ - Remove a tenant and all its nodes from the tree. - Time complexity: O(n) where n is the number of nodes owned by the tenant. + Remove multiple tenants and all their nodes from the tree. + Time complexity: O(n) where n is the total number of nodes owned by all tenants. Args: - tenant: Tenant to remove + tenants: List of tenants to remove Returns: - Number of characters removed (0 if tenant doesn't exist) + Dictionary mapping each tenant to the number of characters removed + (0 if tenant doesn't exist) """ + chars_removed: Dict[str, int] = {} + with self.lock: - if tenant not in self.tenant_to_char_count: - logger.warning(f"Tenant '{tenant}' does not exist. 
No action taken.") - return 0 + for tenant in tenants: + if tenant not in self.tenant_to_char_count: + logger.debug( + f"[_remove_tenants] Tenant '{tenant}' does not exist. Skipping." + ) + chars_removed[tenant] = 0 + continue - total_chars_removed: int = 0 + tenant_chars_removed: int = 0 - # Start from the tail and remove all nodes - current_tail = self.tenant_to_lru_tail.get(tenant) - while current_tail: - newer_neighbor = current_tail.tenant_to_newer_node.get(tenant) - total_chars_removed += self._remove_tenant_single_node( - tenant, current_tail - ) - current_tail = newer_neighbor + # Start from the tail and remove all nodes + current_tail = self.tenant_to_lru_tail.get(tenant) + while current_tail: + newer_neighbor = current_tail.tenant_to_newer_node.get(tenant) + tenant_chars_removed += self._remove_tenant_single_node( + tenant, current_tail + ) + current_tail = newer_neighbor - # Clean up tenant references - self.tenant_to_char_count.pop(tenant, None) - self.tenant_to_lru_tail.pop(tenant, None) + # Clean up tenant references + self.tenant_to_char_count.pop(tenant, None) + self.tenant_to_lru_tail.pop(tenant, None) - return total_chars_removed + chars_removed[tenant] = tenant_chars_removed + + return chars_removed def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ @@ -485,14 +512,14 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ with self.lock: if tenant not in self.tenant_to_char_count: - logger.warning( - f"Cannot evict tenant '{tenant}': tenant does not exist. No action taken." + logger.debug( + f"[_evict_tenant_by_lru] Cannot evict tenant '{tenant}': tenant does not exist. No action taken." 
) return 0 if self.tenant_to_char_count[tenant] < min_remove_size: - logger.warning( - f"Cannot evict {min_remove_size} characters from tenant '{tenant}', which has only " + logger.debug( + f"[_evict_tenant_by_lru] Cannot evict {min_remove_size} characters from tenant '{tenant}', which has only " f"{self.tenant_to_char_count[tenant]} characters. Will remove all available characters." ) min_remove_size = self.tenant_to_char_count[tenant] @@ -525,22 +552,65 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: return total_chars_removed - def get_smallest_tenant(self) -> Optional[str]: + def get_smallest_tenants(self) -> Optional[List[str]]: """ - Get the tenant with the smallest total character count. + Get the tenants with the smallest total character count. Returns: - Tenant with smallest character count, or None if no tenants + Tenants with smallest character count, or None if no tenants """ with self.lock: if not self.tenant_to_char_count: return None - return min( - self.tenant_to_char_count, - key=self.tenant_to_char_count.get, - default=None, - ) + min_count = min(self.tenant_to_char_count.values()) + return [ + tenant + for tenant, count in self.tenant_to_char_count.items() + if count == min_count + ] + + def start_eviction_loop( + self, eviction_threshold: int, eviction_target: int, interval_secs: float + ) -> bool: + """Start a single eviction loop within the actor itself. + Args: + eviction_threshold: A tenant is evicted only when its character count exceeds this value + eviction_target: The maximum number of characters a tenant should have after eviction + interval_secs: Number of seconds between eviction checks + + Returns: + True if the loop was started, False if it was already running + """ + with self.lock: + if self._eviction_task is None: + self._eviction_task = asyncio.create_task( + self._run_eviction_loop( + eviction_threshold, eviction_target, interval_secs + ) + ) + return True + else: + logger.debug("[_start_eviction_loop]
Eviction loop already running") + return False + + async def _run_eviction_loop( + self, eviction_threshold, eviction_target, interval_secs + ): + while True: + await asyncio.sleep(interval_secs) + with self.lock: + for tenant, char_count in self.tenant_to_char_count.items(): + if char_count > eviction_threshold: + excess = char_count - eviction_target + self.evict_tenant_by_lru(tenant, excess) + + def stop_eviction_loop(self): + with self.lock: + if self._eviction_task: + self._eviction_task.cancel() + # self._eviction_task.close() + self._eviction_task = None @ray.remote @@ -551,3 +621,6 @@ def getattr(self, attribute: str) -> Any: Note: This method is intended to be used only in tests. """ return getattr(self, attribute) + + def setattr(self, attribute: str, value: Any) -> None: + setattr(self, attribute, value) diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py new file mode 100644 index 000000000000..4ecd45f0dceb --- /dev/null +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py @@ -0,0 +1,330 @@ +import asyncio +import time + +import pytest + +import ray +from ray._common.utils import get_or_create_event_loop +from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import ( + PrefixTreeActor, +) +from ray.serve._private.common import ( + DeploymentHandleSource, + DeploymentID, + RequestMetadata, +) +from ray.serve._private.request_router.common import PendingRequest +from ray.serve._private.request_router.prefix_aware_router import ( + PrefixAwarePow2ReplicaRouter, +) +from ray.serve._private.test_utils import MockTimer +from ray.serve._private.utils import generate_request_id +from ray.serve.tests.unit.test_pow_2_request_router import ( + FakeRunningReplica, +) # Reuse the FakeRunningReplica from the Pow2 test + +TIMER = MockTimer() +DEFAULT_MAX_ONGOING_REQUESTS = 10 + + +# === Fixtures === + + 
+@pytest.fixture +def tree_actor(): + """Create a fresh PrefixTreeActor instance.""" + actor = PrefixTreeActor.options(name="PrefixTreeActor").remote() + yield actor + ray.kill(actor) + + +@pytest.fixture +def prefix_request_router(tree_actor, request): + """Create a fresh PrefixAwarePow2ReplicaRouter with connected tree_actor.""" + params = getattr(request, "param", {}) + + async def construct_request_router(loop: asyncio.AbstractEventLoop): + request_router = PrefixAwarePow2ReplicaRouter( + deployment_id=DeploymentID(name="TEST_DEPLOYMENT"), + handle_source=DeploymentHandleSource.REPLICA, + use_replica_queue_len_cache=False, + imbalanced_threshold=params.get("imbalanced_threshold", 10), + match_rate_threshold=params.get("match_rate_threshold", 0.1), + do_eviction=params.get("do_eviction", False), + eviction_threshold_chars=params.get("eviction_threshold_chars"), + eviction_target_chars=params.get("eviction_target_chars"), + eviction_interval_secs=params.get("eviction_interval_secs"), + get_curr_time_s=TIMER.time, + tree_actor=tree_actor, + ) + return request_router + + request_router = asyncio.new_event_loop().run_until_complete( + construct_request_router(get_or_create_event_loop()) + ) + + yield request_router + assert request_router.curr_num_routing_tasks == 0 + assert request_router.num_pending_requests == 0 + + +# === Helpers === + + +class PromptRequest: + def __init__(self, prompt: str): + self.prompt = prompt + + +class ChatRequest: + def __init__(self, messages): + self.messages = messages + + +def fake_pending_request(prompt=None, messages=None) -> PendingRequest: + if prompt is not None: + args = [PromptRequest(prompt)] + elif messages is not None: + args = [ChatRequest(messages)] + else: + args = [] + + return PendingRequest( + args=args, + kwargs={}, + metadata=RequestMetadata( + request_id=generate_request_id(), + internal_request_id=generate_request_id(), + multiplexed_model_id="", + ), + created_at=time.time(), + ) + + +# === Tests === +class 
TestPow2FallbackBehavior: + """Tests fallback to Pow2 when prefix-aware logic should be skipped.""" + + @pytest.mark.asyncio + async def test_fallback_when_no_prompt(self, prefix_request_router): + """No args → prefix logic skipped → falls back to least busy replica.""" + r1 = FakeRunningReplica("r1") + r1.set_queue_len_response(0) + r2 = FakeRunningReplica("r2") + r2.set_queue_len_response(5) + prefix_request_router.update_replicas([r1, r2]) + + tenant_to_char_count = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_to_char_count == { + r1.replica_id.to_full_id_str(): 0, + r2.replica_id.to_full_id_str(): 0, + } + + req = fake_pending_request() + for _ in range(10): + chosen = await prefix_request_router.choose_replica_for_request(req) + assert chosen == r1 + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "prefix_request_router", [{"imbalanced_threshold": 2}], indirect=True + ) + async def test_fallback_when_imbalanced(self, prefix_request_router): + """If load is imbalanced beyond threshold, prefix matching is skipped.""" + r1 = FakeRunningReplica("r1") + r1.set_queue_len_response(0) + r2 = FakeRunningReplica("r2") + r2.set_queue_len_response(10) + prefix_request_router.update_replicas([r1, r2]) + + ray.get( + prefix_request_router._tree_actor.insert.remote( + "hello world", r2.replica_id.to_full_id_str(), time.time() + ) + ) + + tenant_to_char_count = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_to_char_count == { + r1.replica_id.to_full_id_str(): 0, + r2.replica_id.to_full_id_str(): 11, + } + + matched_text, matched_tenants = ray.get( + prefix_request_router._tree_actor.prefix_match.remote("hello world") + ) + assert matched_text == "hello world" + assert matched_tenants == [r2.replica_id.to_full_id_str()] + + req = fake_pending_request(prompt="hello world") + for _ in range(10): + chosen = await 
prefix_request_router.choose_replica_for_request(req) + # Even though r2 has a higher match rate, it is not chosen because the load is imbalanced + assert chosen == r1 + + +class TestPrefixAwareLogic: + """Tests that exercise actual prefix-aware request routing logic.""" + + @pytest.mark.asyncio + async def test_high_match_rate_selects_matching_replica( + self, prefix_request_router + ): + """High match rate → use matched replica instead of Pow2.""" + r1 = FakeRunningReplica("r1") + r1.set_queue_len_response(0) + r2 = FakeRunningReplica("r2") + r2.set_queue_len_response(0) + prefix_request_router.update_replicas([r1, r2]) + ray.get( + prefix_request_router._tree_actor.insert.remote( + "Hello", r2.replica_id.to_full_id_str(), time.time() + ) + ) + # Verify prefix match and smallest tenants + matched_text, matched_tenants = ray.get( + prefix_request_router._tree_actor.prefix_match.remote("Hello world") + ) + assert matched_text == "Hello" + assert matched_tenants == [r2.replica_id.to_full_id_str()] + + tenant_counts = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_counts[r1.replica_id.to_full_id_str()] == 0 + assert tenant_counts[r2.replica_id.to_full_id_str()] == 5 + + prompt_req = fake_pending_request(prompt="Hello world") + for _ in range(10): + chosen = await prefix_request_router.choose_replica_for_request(prompt_req) + assert chosen == r2 + chat_req = fake_pending_request( + messages=[{"content": "Hello"}, {"content": " world"}] + ) + for _ in range(10): + chosen = await prefix_request_router.choose_replica_for_request(chat_req) + assert chosen == r2 + + @pytest.mark.asyncio + async def test_low_match_rate_uses_smallest_tree(self, prefix_request_router): + """Low match rate → use replica with least total inserted characters.""" + r1 = FakeRunningReplica("r1") + r1.set_queue_len_response(0) + r2 = FakeRunningReplica("r2") + r2.set_queue_len_response(0) + prefix_request_router.update_replicas([r1, r2]) + + 
# Make r2 "bigger" tenant + ray.get( + prefix_request_router._tree_actor.insert.remote( + "hi", r1.replica_id.to_full_id_str(), time.time() + ) + ) + ray.get( + prefix_request_router._tree_actor.insert.remote( + "longtext", r2.replica_id.to_full_id_str(), time.time() + ) + ) + + # Verify tenant character counts + tenant_counts = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_counts[r1.replica_id.to_full_id_str()] == 2 # "hi" + assert tenant_counts[r2.replica_id.to_full_id_str()] == 8 # "longtext" + + prompt_req = fake_pending_request(prompt="z") + for _ in range(10): + # Both tenants have 0% match rate, so the smaller tenant (r1) is chosen + assert ( + await prefix_request_router.choose_replica_for_request(prompt_req) == r1 + ) + + chat_req = fake_pending_request(messages=[{"content": "z"}]) + for _ in range(10): + # Both tenants have 0% match rate, so the smaller tenant (r1) is chosen + assert ( + await prefix_request_router.choose_replica_for_request(chat_req) == r1 + ) + + +class TestEvictionBehavior: + """Tests for prefix tree eviction behavior.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "prefix_request_router", + [ + { + "do_eviction": True, + "eviction_threshold_chars": 10, + "eviction_target_chars": 5, + "eviction_interval_secs": 1.0, + } + ], + indirect=True, + ) + async def test_eviction_task_creation(self, prefix_request_router): + """Test that eviction task is only created after update_replicas.""" + # Before update_replicas + assert not prefix_request_router._eviction_loop_running + + # After update_replicas + r1 = FakeRunningReplica("r1") + prefix_request_router.update_replicas([r1]) + assert prefix_request_router._eviction_loop_running + + # After stop_eviction_loop + ray.get(prefix_request_router._tree_actor.stop_eviction_loop.remote()) + await asyncio.sleep(0.1) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "prefix_request_router", + [ + { + "do_eviction": True, + 
"eviction_threshold_chars": 10, + "eviction_target_chars": 5, + "eviction_interval_secs": 1.0, + } + ], + indirect=True, + ) + async def test_eviction_threshold_behavior(self, prefix_request_router): + """Test that eviction reduces tree size below threshold after interval.""" + r1 = FakeRunningReplica("r1") + prefix_request_router.update_replicas([r1]) + + # Insert text that exceeds eviction_threshold_chars + ray.get( + prefix_request_router._tree_actor.insert.remote( + "verylongtext", r1.replica_id.to_full_id_str(), time.time() + ) + ) + ray.get( + prefix_request_router._tree_actor.insert.remote( + "anotherlongtext", r1.replica_id.to_full_id_str(), time.time() + ) + ) + + # Verify initial size exceeds eviction_threshold_chars + tenant_counts = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_counts[r1.replica_id.to_full_id_str()] > 10 + + # Wait for eviction interval + await asyncio.sleep(1.1) + + # Verify size is reduced below eviction_target_chars + tenant_counts = ray.get( + prefix_request_router._tree_actor.getattr.remote("tenant_to_char_count") + ) + assert tenant_counts[r1.replica_id.to_full_id_str()] <= 5 + + ray.get(prefix_request_router._tree_actor.stop_eviction_loop.remote()) + await asyncio.sleep(0.1) diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index 9e83884c41a5..9b62e70b7bf4 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -1,9 +1,10 @@ +import asyncio from typing import List, Set import pytest import ray -from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( +from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import ( Node, PrefixTree, PrefixTreeActor, @@ -52,33 +53,88 @@ def test_initial_state(self, tree: PrefixTree) -> None: assert 
tree.root.edge_label_to_child == {} def test_add_tenant(self, tree: PrefixTree) -> None: - """Test adding a new tenant via _add_tenant.""" - tree._add_tenant("tenant_1") + """Test adding a new tenant via add_tenants.""" + tree.add_tenants(["tenant_1"], 0) assert tree.tenant_to_char_count == {"tenant_1": 0} assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root - # _add_tenant itself doesn't update root's access time for the tenant. - assert tree.root.tenant_to_last_access_time == {} + assert tree.root.tenant_to_last_access_time == {"tenant_1": 0} assert get_lru_texts_from_tree(tree, "tenant_1") == [""] def test_add_existing_tenant_noop(self, tree: PrefixTree) -> None: - """Test that adding an existing tenant via _add_tenant is a no-op.""" - tree._add_tenant("tenant_1") + """Test that adding an existing tenant via add_tenants is a no-op.""" + tree.add_tenants(["tenant_1"], 0) assert tree.tenant_to_char_count == {"tenant_1": 0} assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root - assert tree.root.tenant_to_last_access_time == {} + assert tree.root.tenant_to_last_access_time == {"tenant_1": 0} assert get_lru_texts_from_tree(tree, "tenant_1") == [""] - tree._add_tenant("tenant_1") # Add again + tree.add_tenants(["tenant_1"], 0) # Add again assert tree.tenant_to_char_count == {"tenant_1": 0} assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root - assert tree.root.tenant_to_last_access_time == {} + assert tree.root.tenant_to_last_access_time == {"tenant_1": 0} assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + def test_add_multiple_tenants(self, tree: PrefixTree) -> None: + """Test adding multiple tenants at once.""" + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) + + assert tree.tenant_to_char_count == { + "tenant_1": 0, + "tenant_2": 0, + "tenant_3": 0, + } + for tenant in ["tenant_1", "tenant_2", "tenant_3"]: + assert tree.tenant_to_lru_tail.get(tenant) == tree.root + assert tree.root.tenant_to_newer_node.get(tenant) is None + 
assert tree.root.tenant_to_older_node.get(tenant) is None + assert tree.root.tenant_to_last_access_time == { + "tenant_1": 0, + "tenant_2": 0, + "tenant_3": 0, + } + assert get_lru_texts_from_tree(tree, tenant) == [""] + + def test_add_multiple_tenants_with_existing(self, tree: PrefixTree) -> None: + """Test adding multiple tenants when some already exist.""" + tree.add_tenants(["tenant_1"], 0) + assert tree.root.tenant_to_last_access_time == {"tenant_1": 0} + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert "tenant_1" in tree.tenant_to_lru_tail + + # Add a mix of new and existing tenants + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) + # Existing tenants should remain unchanged + assert tree.root.tenant_to_last_access_time == { + "tenant_1": 0, + "tenant_2": 0, + "tenant_3": 0, + } + assert tree.tenant_to_char_count == { + "tenant_1": 0, + "tenant_2": 0, + "tenant_3": 0, + } + assert all( + tenant in tree.tenant_to_lru_tail + for tenant in ["tenant_1", "tenant_2", "tenant_3"] + ) + class TestPrefixTreeInsert: + def test_insert_non_existent_tenant(self, tree: PrefixTree) -> None: + """Test inserting a string for a non-existent tenant fails.""" + # Insert without adding tenant first + tree.insert("hello", "nonexistent", 1) + + # Verify insert did nothing since tenant doesn't exist + assert "nonexistent" not in tree.tenant_to_char_count + assert get_lru_texts_from_tree(tree, "nonexistent") == [] + assert "h" not in tree.root.edge_label_to_child + def test_insert_single_string(self, tree: PrefixTree) -> None: - """Test inserting a single string, which also adds a new tenant.""" + """Test inserting a single string after adding a tenant.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", "tenant_1", 1) assert tree.tenant_to_char_count == {"tenant_1": 5} assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello"] @@ -95,6 +151,7 @@ def test_insert_single_string(self, tree: PrefixTree) -> None: def test_insert_duplicate_string(self, 
tree: PrefixTree) -> None: """Test inserting a duplicate string for the same tenant.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", "tenant_1", 1) # Initial insert tree.insert("hello", "tenant_1", 1) # Duplicate insert with the same timestamp @@ -122,6 +179,7 @@ def test_insert_duplicate_string(self, tree: PrefixTree) -> None: def test_insert_multiple_tenants(self, tree: PrefixTree) -> None: """Test inserting the same string for different tenants.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("hello", "tenant_2", 2) @@ -135,6 +193,7 @@ def test_insert_multiple_tenants(self, tree: PrefixTree) -> None: def test_insert_node_split(self, tree: PrefixTree) -> None: """Test insertion that causes an existing node to split due to differing suffixes.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) # "hello" is common prefix @@ -157,6 +216,7 @@ def test_insert_node_split(self, tree: PrefixTree) -> None: def test_insert_longer_string_with_shared_prefix(self, tree: PrefixTree) -> None: """Test inserting a longer string that shares a prefix with an existing node string.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("helloworld", "tenant_2", 2) # "hello" is prefix of "helloworld" @@ -189,6 +249,7 @@ def test_insert_longer_string_with_shared_prefix(self, tree: PrefixTree) -> None def test_insert_shorter_string_with_shared_prefix(self, tree: PrefixTree) -> None: """Test inserting a shorter string that is a prefix of an existing longer string, causing split.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert( "hello", "tenant_2", 2 @@ -217,6 +278,7 @@ def test_prefix_match_empty_tree(self, tree: PrefixTree) -> None: def test_prefix_match_no_match(self, tree: PrefixTree) -> None: """Test prefix_match for a non-matching prefix 
returns empty string and all tenants.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("world", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("foobar") @@ -228,6 +290,7 @@ def test_prefix_match_query_longer_than_stored_strings( self, tree: PrefixTree ) -> None: """Test prefix_match where query is longer than any stored string but matches a full path.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("hellothereextra") @@ -236,6 +299,7 @@ def test_prefix_match_query_longer_than_stored_strings( def test_prefix_match_exact_match(self, tree: PrefixTree) -> None: """Test prefix_match with an exact match for a single tenant.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", "tenant_1", 1) matched_text, matched_tenants = tree.prefix_match("hello") assert matched_text == "hello" @@ -243,6 +307,7 @@ def test_prefix_match_exact_match(self, tree: PrefixTree) -> None: def test_prefix_match_partial_match(self, tree: PrefixTree) -> None: """Test prefix_match with a partial query matching the longest common part of a branch.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("application") @@ -251,16 +316,47 @@ def test_prefix_match_partial_match(self, tree: PrefixTree) -> None: def test_prefix_match_with_tenant_filter(self, tree: PrefixTree) -> None: """Test prefix_match with a tenant filter selecting a specific branch.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("application", ["tenant_2"]) assert matched_text == "ap" assert matched_tenants == ["tenant_2"] + def test_prefix_match_with_shared_prefix_tenant_filter( + 
self, tree: PrefixTree + ) -> None: + """Test prefix_match with a tenant filter when one tenant has a prefix of a longer string.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) + tree.insert("apple", "tenant_1", 1) + tree.insert("applepie", "tenant_2", 2) + + # Match the longer string but only allow tenant_1 + matched_text, matched_tenants = tree.prefix_match("applepie", ["tenant_1"]) + + # Should only match up to "apple" as that's what tenant_1 owns + assert matched_text == "apple" + assert matched_tenants == ["tenant_1"] + + # Verify that using both tenants would match the full string for tenant_2 only + matched_text, matched_tenants = tree.prefix_match( + "applepie", ["tenant_1", "tenant_2"] + ) + assert matched_text == "applepie" + assert matched_tenants == ["tenant_2"] + + # And both tenants should be returned for "apple" + matched_text, matched_tenants = tree.prefix_match( + "apple", ["tenant_1", "tenant_2"] + ) + assert matched_text == "apple" + assert set(matched_tenants) == {"tenant_1", "tenant_2"} + def test_prefix_match_with_non_existent_tenant_filter( self, tree: PrefixTree ) -> None: """Test prefix_match with a filter for a non-existent tenant returns no match.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("apple", "tenant_1", 1) matched_text, matched_tenants = tree.prefix_match( "application", ["non_existent_tenant"] @@ -272,6 +368,7 @@ def test_prefix_match_with_non_existent_tenant_filter( class TestPrefixTreeRemove: def test_remove_single_leaf_node_pruned(self, tree: PrefixTree) -> None: """Test _remove_tenant_single_node for a leaf node; node should be pruned.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", "tenant_1", 1) hello_node = tree.root.edge_label_to_child["h"] assert hello_node.tenant_to_last_access_time == {"tenant_1": 1} @@ -286,6 +383,7 @@ def test_remove_single_leaf_node_pruned(self, tree: PrefixTree) -> None: def test_remove_single_leaf_node_not_pruned(self, tree: PrefixTree) -> None: """Test 
_remove_tenant_single_node for a leaf node; node should not be pruned.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("hello", "tenant_2", 2) hello_node = tree.root.edge_label_to_child["h"] @@ -303,6 +401,7 @@ def test_remove_single_node_with_non_existent_tenant( self, tree: PrefixTree ) -> None: """Test _remove_tenant_single_node for a non-existent tenant is a no-op.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", "tenant_1", 1) hello_node = tree.root.edge_label_to_child["h"] removed_chars = tree._remove_tenant_single_node( @@ -314,6 +413,7 @@ def test_remove_single_node_with_non_matching_tenant( self, tree: PrefixTree ) -> None: """Test _remove_tenant_single_node if node doesn't belong to specified tenant is a no-op.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("world", "tenant_2", 2) # Node for tenant_2 hello_node = tree.root.edge_label_to_child["h"] # Belongs to tenant_1 @@ -324,11 +424,12 @@ def test_remove_single_node_with_non_matching_tenant( def test_remove_tenant(self, tree: PrefixTree) -> None: """Test remove_tenant for a tree with multiple tenants only removes the specified tenant.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("hello", "tenant_1", 1) tree.insert("foobar", "tenant_1", 2) tree.insert("helloworld", "tenant_2", 3) - removed_chars = tree.remove_tenant("tenant_1") - assert removed_chars == 11 + removed_chars = tree.remove_tenants(["tenant_1"]) + assert removed_chars == {"tenant_1": 11} hello_node = tree.root.edge_label_to_child["h"] assert hello_node.tenant_to_last_access_time == {"tenant_2": 3} assert tree.tenant_to_char_count == {"tenant_2": 10} @@ -338,30 +439,77 @@ def test_remove_tenant(self, tree: PrefixTree) -> None: def test_remove_non_existent_tenant(self, tree: PrefixTree) -> None: """Test remove_tenant for a non-existent tenant returns 0.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("hello", 
"tenant_1", 1) - removed_chars = tree.remove_tenant("non_existent_tenant") - assert removed_chars == 0 + removed_chars = tree.remove_tenants(["non_existent_tenant"]) + assert removed_chars == {"non_existent_tenant": 0} def test_remove_tenant_prunes_nodes(self, tree: PrefixTree) -> None: """Test remove_tenant prunes nodes that become tenant-less and childless.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) # Creates "helloworld" tree.insert( "hellothere", "tenant_2", 2 ) # Splits into "hello" -> "world" and "hello" -> "there" - tree.remove_tenant( - "tenant_1" - ) # "world" node should be pruned. "hello" and "there" remain for tenant_2. + tree.remove_tenants(["tenant_1"]) + # "world" node should be pruned. "hello" and "there" remain for tenant_2. hello_node = tree.root.edge_label_to_child["h"] - assert set(hello_node.edge_label_to_child.keys()) == { - "t" - } # "w" (world) child is gone + assert set(hello_node.edge_label_to_child.keys()) == {"t"} assert hello_node.edge_label_to_child["t"].text == "there" assert hello_node.edge_label_to_child["t"].tenant_to_last_access_time == { "tenant_2": 2 } + def test_remove_tenants(self, tree: PrefixTree) -> None: + """Test remove_tenants for multiple tenants with different structures.""" + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) + tree.insert("hello", "tenant_1", 1) # 5 chars + tree.insert("foobar", "tenant_1", 2) # 6 chars + tree.insert("helloworld", "tenant_2", 3) # 10 chars + tree.insert("test", "tenant_3", 4) # 4 chars + + removed_chars = tree.remove_tenants(["tenant_1", "tenant_3"]) + + # Check return value contains correct char counts + assert removed_chars == {"tenant_1": 11, "tenant_3": 4} + + # Check tree state is correct + assert "tenant_1" not in tree.tenant_to_char_count + assert "tenant_3" not in tree.tenant_to_char_count + assert "tenant_2" in tree.tenant_to_char_count + assert tree.tenant_to_char_count == {"tenant_2": 10} + + # Check nodes are 
correctly maintained + assert ( + "h" in tree.root.edge_label_to_child + ) # hello node still exists for tenant_2 + assert "t" not in tree.root.edge_label_to_child # test node removed + assert "f" not in tree.root.edge_label_to_child # foobar node removed + + # Check LRU structure + assert set(tree.tenant_to_lru_tail.keys()) == {"tenant_2"} + tenant_2_lru_texts = get_lru_texts_from_tree(tree, "tenant_2") + assert tenant_2_lru_texts == ["", "world", "hello"] + + def test_remove_tenants_with_nonexistent(self, tree: PrefixTree) -> None: + """Test remove_tenants with a mix of existing and non-existent tenants.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) + tree.insert("hello", "tenant_1", 1) + tree.insert("world", "tenant_2", 2) + + removed_chars = tree.remove_tenants(["tenant_1", "nonexistent", "alsonotfound"]) + + # Check return value + assert removed_chars == {"tenant_1": 5, "nonexistent": 0, "alsonotfound": 0} + + # Check tree state + assert "tenant_1" not in tree.tenant_to_char_count + assert tree.tenant_to_char_count == {"tenant_2": 5} + assert "h" not in tree.root.edge_label_to_child # hello node removed + assert "w" in tree.root.edge_label_to_child # world node still exists + class TestPrefixTreeEviction: def test_eviction_non_existent_tenant(self, tree: PrefixTree) -> None: @@ -370,6 +518,7 @@ def test_eviction_non_existent_tenant(self, tree: PrefixTree) -> None: def test_eviction_exact_min_remove_size_single_node(self, tree: PrefixTree) -> None: """Test evicting exactly min_remove_size characters from a single oldest node.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("a", "tenant_1", 1) # Oldest (1 char) tree.insert("bb", "tenant_1", 2) tree.insert("ccc", "tenant_1", 3) @@ -384,6 +533,7 @@ def test_eviction_exceed_min_remove_size_single_node( self, tree: PrefixTree ) -> None: """Test evicting more than min_remove_size characters from a single oldest node.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("aaa", "tenant_1", 1) # Oldest (2 chars) 
tree.insert("bb", "tenant_1", 2) tree.insert("c", "tenant_1", 3) @@ -396,6 +546,7 @@ def test_eviction_exceed_min_remove_size_single_node( def test_eviction_multiple_nodes(self, tree: PrefixTree) -> None: """Test evicting multiple oldest nodes to meet min_remove_size.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("a", "tenant_1", 1) # Oldest (1 char) tree.insert("bb", "tenant_1", 2) # Next oldest (2 chars) tree.insert("ccc", "tenant_1", 3) @@ -408,6 +559,7 @@ def test_eviction_multiple_nodes(self, tree: PrefixTree) -> None: def test_eviction_same_timestamps(self, tree: PrefixTree) -> None: """Test evicting more than min_remove_size if multiple nodes share the oldest timestamp.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello", "world"] @@ -422,6 +574,7 @@ def test_eviction_same_timestamps(self, tree: PrefixTree) -> None: def test_eviction_insufficient_chars_evicts_all(self, tree: PrefixTree) -> None: """Test evicting when min_remove_size is larger than available; evicts all.""" + tree.add_tenants(["tenant_1"], 0) tree.insert("xyz", "tenant_1", 1) # 3 chars available evicted_count = tree.evict_tenant_by_lru("tenant_1", 10) assert evicted_count == 3 @@ -429,27 +582,40 @@ def test_eviction_insufficient_chars_evicts_all(self, tree: PrefixTree) -> None: assert get_lru_texts_from_tree(tree, "tenant_1") == [""] -class TestPrefixTreeGetSmallestTenant: - def test_get_smallest_tenant(self, tree: PrefixTree) -> None: - """Test get_smallest_tenant identifies the tenant with the fewest characters.""" +class TestPrefixTreeGetSmallestTenants: + """Tests for the get_smallest_tenants method.""" + + def test_get_smallest_tenants(self, tree: PrefixTree) -> None: + """Test get_smallest_tenants identifies the tenant with the fewest characters.""" + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) tree.insert("aaaa", 
"tenant_1", 1) # 4 chars tree.insert("bb", "tenant_2", 2) # 2 chars tree.insert("c", "tenant_3", 3) # 1 char - assert tree.get_smallest_tenant() == "tenant_3" + smallest_tenants = tree.get_smallest_tenants() + assert smallest_tenants == ["tenant_3"] - def test_get_smallest_tenant_empty_tree(self, tree: PrefixTree) -> None: - """Test get_smallest_tenant on an empty tree returns None.""" - assert tree.get_smallest_tenant() is None + def test_get_smallest_tenants_empty_tree(self, tree: PrefixTree) -> None: + """Test get_smallest_tenants on an empty tree returns None.""" + assert tree.get_smallest_tenants() is None - def test_get_smallest_tenant_after_update(self, tree: PrefixTree) -> None: - """Test get_smallest_tenant after removing the current smallest tenant.""" + def test_get_smallest_tenants_after_update(self, tree: PrefixTree) -> None: + """Test get_smallest_tenants after removing the current smallest tenant.""" + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) tree.insert("aaaa", "tenant_1", 1) tree.insert("bb", "tenant_2", 2) tree.insert("c", "tenant_3", 3) - tree.remove_tenant("tenant_3") # Remove "c" (1 char) - assert ( - tree.get_smallest_tenant() == "tenant_2" - ) # "bb" (2 chars) is now smallest + tree.remove_tenants(["tenant_3"]) # Remove "c" (1 char) + smallest_tenants = tree.get_smallest_tenants() + assert smallest_tenants == ["tenant_2"] # "bb" (2 chars) is now smallest + + def test_get_smallest_tenants_with_ties(self, tree: PrefixTree) -> None: + """Test get_smallest_tenants when multiple tenants have the same minimum count.""" + tree.add_tenants(["tenant_1", "tenant_2", "tenant_3"], 0) + tree.insert("aa", "tenant_1", 1) # 2 chars + tree.insert("bb", "tenant_2", 2) # 2 chars + tree.insert("cccc", "tenant_3", 3) # 4 chars + smallest_tenants = tree.get_smallest_tenants() + assert set(smallest_tenants) == {"tenant_1", "tenant_2"} class TestPrefixTreeComprehensive: @@ -457,6 +623,7 @@ class TestPrefixTreeComprehensive: def 
test_tree_structure_multiple_insertions(self, tree: PrefixTree) -> None: """Test tree structure after multiple insertions.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) tree.insert("hellothomas", "tenant_2", 3) @@ -508,6 +675,7 @@ def test_tree_structure_multiple_insertions(self, tree: PrefixTree) -> None: def test_multiple_evictions_maintains_lru_order(self, tree: PrefixTree) -> None: """Test multiple evictions maintain LRU order.""" + tree.add_tenants(["tenant_1", "tenant_2"], 0) tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) tree.insert("hellothomas", "tenant_2", 3) @@ -554,10 +722,11 @@ class TestPrefixTreeActorComprehensive: async def test_tree_structure_multiple_insertions_actor( self, tree_actor: PrefixTreeActor ) -> None: - # Insert strings in specified order - tree_actor.insert.remote("helloworld", "tenant_1", 1) - tree_actor.insert.remote("hellothere", "tenant_2", 2) - tree_actor.insert.remote("hellothomas", "tenant_2", 3) + # Add tenants and insert strings in specified order + ray.get(tree_actor.add_tenants.remote(["tenant_1", "tenant_2"], 0)) + ray.get(tree_actor.insert.remote("helloworld", "tenant_1", 1)) + ray.get(tree_actor.insert.remote("hellothere", "tenant_2", 2)) + ray.get(tree_actor.insert.remote("hellothomas", "tenant_2", 3)) assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_1") == [ "", "hello", @@ -613,9 +782,11 @@ async def test_multiple_evictions_maintains_lru_order_actor( self, tree_actor: PrefixTreeActor ) -> None: """Test multiple evictions maintain LRU order.""" - tree_actor.insert.remote("helloworld", "tenant_1", 1) - tree_actor.insert.remote("hellothere", "tenant_2", 2) - tree_actor.insert.remote("hellothomas", "tenant_2", 3) + # Add tenants and insert test data + ray.get(tree_actor.add_tenants.remote(["tenant_1", "tenant_2"], 0)) + ray.get(tree_actor.insert.remote("helloworld", "tenant_1", 1)) + 
ray.get(tree_actor.insert.remote("hellothere", "tenant_2", 2)) + ray.get(tree_actor.insert.remote("hellothomas", "tenant_2", 3)) assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { "tenant_1": 10, "tenant_2": 14, @@ -634,7 +805,7 @@ async def test_multiple_evictions_maintains_lru_order_actor( ] # Eviction 1 (tenant_1): min_remove_size=1. "hello" and "world" removed. - evicted_1 = await tree_actor.evict_tenant_by_lru.remote("tenant_1", 1) + evicted_1 = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_1", 1)) assert evicted_1 == 10 assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { "tenant_1": 0, @@ -650,7 +821,7 @@ async def test_multiple_evictions_maintains_lru_order_actor( ] # T2 unchanged # Eviction 2 (tenant_2): min_remove_size=1. "ere" is oldest timestamp, removed. - evicted_2 = await tree_actor.evict_tenant_by_lru.remote("tenant_2", 1) + evicted_2 = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) assert evicted_2 == 3 # "ere" is 3 chars assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { "tenant_1": 0, @@ -664,7 +835,7 @@ async def test_multiple_evictions_maintains_lru_order_actor( ] # Eviction 3 (tenant_2): min_remove_size=1. "omas"(ts3), "th"(ts3), "hello"(ts3) removed. 
- evicted_3 = await tree_actor.evict_tenant_by_lru.remote("tenant_2", 1) + evicted_3 = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) assert evicted_3 == 11 # 4+2+5 chars assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { "tenant_1": 0, @@ -673,6 +844,154 @@ async def test_multiple_evictions_maintains_lru_order_actor( assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_2") == [""] +@pytest.mark.asyncio +class TestPrefixTreeActorEvictionLoop: + """Tests for the automatic eviction loop in PrefixTreeActor""" + + async def test_eviction_loop_triggers_automatically( + self, tree_actor: PrefixTreeActor + ) -> None: + """Test that the eviction loop automatically evicts data when threshold is exceeded.""" + # Set up eviction parameters + eviction_threshold = 10 # Low threshold for testing + eviction_target = 8 # Target to evict down to + interval_secs = 0.1 # Short interval for testing + + # Start the eviction loop + ray.get( + tree_actor.start_eviction_loop.remote( + eviction_threshold, eviction_target, interval_secs + ) + ) + + # Add tenant and insert data over the threshold + ray.get(tree_actor.add_tenants.remote(["tenant_1"], 0)) + ray.get(tree_actor.insert.remote("hello", "tenant_1", 1)) # 5 chars + ray.get( + tree_actor.insert.remote("excess", "tenant_1", 2) + ) # 6 more chars, total: 11 + + # Verify initial count + assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { + "tenant_1": 11 + } + + # Wait for eviction loop to run (interval + small buffer) + await asyncio.sleep(interval_secs + 0.2) + + # Verify data was automatically evicted down to target (8 chars) + # The eviction should have removed 5 chars, so we should be at 6, which is <= 8 + char_count = ray.get(tree_actor.getattr.remote("tenant_to_char_count")) + assert char_count["tenant_1"] == 6 + + async def test_eviction_loop_multiple_tenants( + self, tree_actor: PrefixTreeActor + ) -> None: + """Test that eviction loop evicts from each tenant that 
exceeds the threshold.""" + # Set up eviction parameters + eviction_threshold = 10 + eviction_target = 8 + interval_secs = 0.1 + + # Start the eviction loop + ray.get( + tree_actor.start_eviction_loop.remote( + eviction_threshold, eviction_target, interval_secs + ) + ) + + # Add two tenants with data over threshold + ray.get(tree_actor.add_tenants.remote(["tenant_1", "tenant_2"], 0)) + ray.get(tree_actor.insert.remote("hello", "tenant_1", 1)) # 5 chars + ray.get( + tree_actor.insert.remote("excess", "tenant_1", 2) + ) # 6 more chars, total: 11 + ray.get(tree_actor.insert.remote("bigstring", "tenant_2", 3)) # 9 chars + ray.get( + tree_actor.insert.remote("more", "tenant_2", 4) + ) # 4 more chars, total: 13 + + # Verify initial counts + initial_count = ray.get(tree_actor.getattr.remote("tenant_to_char_count")) + assert initial_count["tenant_1"] == 11 + assert initial_count["tenant_2"] == 13 + + # Wait for eviction loop to run + await asyncio.sleep(interval_secs + 0.2) + + # Verify both tenants were evicted to target + char_count = ray.get(tree_actor.getattr.remote("tenant_to_char_count")) + + # Tenant 1 should have "hello" evicted, so 11 - 5 = 6 + assert char_count["tenant_1"] == 6 + # Tenant 2 should have "bigstring" evicted, so 13 - 9 = 4 + assert char_count["tenant_2"] == 4 + + async def test_eviction_loop_respects_threshold( + self, tree_actor: PrefixTreeActor + ) -> None: + """Test that eviction loop only evicts tenants that exceed the threshold.""" + # Set up eviction parameters + eviction_threshold = 10 + eviction_target = 8 + interval_secs = 0.1 + + # Start the eviction loop + ray.get( + tree_actor.start_eviction_loop.remote( + eviction_threshold, eviction_target, interval_secs + ) + ) + + # Add two tenants - one over threshold, one under + ray.get(tree_actor.add_tenants.remote(["over_tenant", "under_tenant"], 0)) + ray.get(tree_actor.insert.remote("hello", "over_tenant", 1)) # 5 chars + ray.get( + tree_actor.insert.remote("excess", "over_tenant", 2) + ) # 6 
more chars, total: 11 + ray.get(tree_actor.insert.remote("small", "under_tenant", 3)) # 5 chars + + # Verify initial counts + initial_count = ray.get(tree_actor.getattr.remote("tenant_to_char_count")) + assert initial_count["over_tenant"] == 11 + assert initial_count["under_tenant"] == 5 + + # Wait for eviction loop to run + await asyncio.sleep(interval_secs + 0.2) + + # Verify only the tenant over threshold was evicted + char_count = ray.get(tree_actor.getattr.remote("tenant_to_char_count")) + # Tenant 1 should have "hello" evicted, so 11 - 5 = 6 + assert char_count["over_tenant"] == 6 + # Tenant 2 should be unchanged + assert char_count["under_tenant"] == 5 + + async def test_eviction_loop_can_be_started_multiple_times( + self, tree_actor: PrefixTreeActor + ) -> None: + """Test that only the first call to start_eviction_loop starts a new loop.""" + # Call start_eviction_loop multiple times + eviction_task_1 = ray.get(tree_actor.start_eviction_loop.remote(10, 8, 0.1)) + eviction_task_2 = ray.get(tree_actor.start_eviction_loop.remote(10, 0, 0.1)) + assert eviction_task_1 and not eviction_task_2 + + # Add tenant and insert data over the threshold + ray.get(tree_actor.add_tenants.remote(["tenant_1"], 0)) + ray.get(tree_actor.insert.remote("hello", "tenant_1", 1)) # 5 chars + ray.get( + tree_actor.insert.remote("excess", "tenant_1", 2) + ) # 6 more chars, total: 11 + + # Wait for eviction loop to run + await asyncio.sleep(0.3) + + # Verify the first eviction_target_chars is respected. + # Should evict "hello" to bring the char count down from 11 to 6. 
class PrefixAwarePow2ReplicaRouter(LocalityMixin, MultiplexMixin, RequestRouter):
    """Request router that prefers replicas likely to hold a matching prefix cache.

    The router combines the locality/multiplex mixins with a shared prefix
    tree (one tenant per replica) and delegates to
    ``PowerOfTwoChoicesRequestRouter.choose_replicas`` for its fallback:

    1. Three strategies are mixed to balance cache hit rate and load:
       - Load balanced (queue length spread <= ``imbalanced_threshold``) and
         the prefix match rate >= ``match_rate_threshold``: pick the replicas
         with the longest prefix match for the input text.
       - Load balanced but the match rate is below ``match_rate_threshold``:
         pick the tenants with the fewest characters in the prefix tree.
       - Load imbalanced (or nothing usable was found): use the default
         power-of-two-choices selection.

    2. After a request is routed, its prompt text is inserted into the prefix
       tree under the chosen replica so later requests can match against it.
    """

    def __init__(
        self,
        *args,
        imbalanced_threshold=10,
        match_rate_threshold=0.1,
        do_eviction=False,
        eviction_threshold_chars=400_000,
        eviction_target_chars=360_000,
        eviction_interval_secs=10,
        tree_actor=None,
        **kwargs,
    ):
        """Create the router and attach it to a prefix tree actor.

        Args:
            *args: Forwarded to the parent router constructor.
            imbalanced_threshold: Maximum allowed difference between the
                highest and lowest known queue lengths before load is treated
                as imbalanced and prefix matching is skipped.
            match_rate_threshold: Minimum ``len(matched) / len(input)`` ratio
                required to route by prefix match.
            do_eviction: Whether to start the tree actor's eviction loop.
            eviction_threshold_chars: Threshold passed to the eviction loop.
            eviction_target_chars: Target passed to the eviction loop; falls
                back to ``eviction_threshold_chars`` when ``None``.
            eviction_interval_secs: Interval passed to the eviction loop.
            tree_actor: Existing prefix tree actor handle to use instead of
                the shared named actor (useful for tests).
            **kwargs: Forwarded to the parent router constructor.
        """
        super().__init__(*args, **kwargs)
        if tree_actor is None:
            # Use a detached actor to avoid issues with actor lifetime since
            # this is shared between routers.
            self._tree_actor = PrefixTreeActor.options(
                name="LlmPrefixTreeActor", get_if_exists=True, lifetime="detached"
            ).remote()
        else:
            self._tree_actor = tree_actor

        # === Prefix-aware routing logic hyperparameters ===
        self._imbalanced_threshold = imbalanced_threshold
        self._match_rate_threshold = match_rate_threshold

        # === Eviction policy ===
        self._do_eviction = do_eviction
        self._eviction_loop_running = False
        self._eviction_threshold_chars = eviction_threshold_chars
        # Fall back to the threshold when no explicit target is given.
        self._eviction_target_chars = (
            eviction_target_chars
            if eviction_target_chars is not None
            else eviction_threshold_chars
        )
        self._eviction_interval_secs = eviction_interval_secs

    def _extract_text_from_request(self, pending_request: PendingRequest) -> str:
        """Extract the text used for prefix matching from a pending request.

        Scans the request arguments in order and, for each one, looks at its
        ``messages`` attribute and then its ``prompt`` attribute. The first
        value that is not ``None`` is used. A list of messages is flattened by
        concatenating every ``content`` entry.

        Args:
            pending_request: The request to extract text from.

        Returns:
            The prompt, or the concatenated message contents.

        Raises:
            ValueError: If no argument carries a non-``None`` ``messages`` or
                ``prompt`` attribute.
        """
        prompt = None
        for arg in pending_request.args:
            for attr_name in ("messages", "prompt"):
                # Skip attributes that exist but are unset, so that an empty
                # `messages` does not hide a usable `prompt` on the same arg.
                value = getattr(arg, attr_name, None)
                if value is not None:
                    prompt = value
                    break
            if prompt is not None:
                break
        if prompt is None:
            raise ValueError(
                "No request with message or prompt attribute found in pending_request.args"
            )

        # Convert list of messages to concatenated string.
        if isinstance(prompt, list):
            return "".join(
                msg.get("content", "") for msg in prompt if "content" in msg
            )
        return prompt

    async def _is_load_imbalanced(
        self, candidate_replicas: List[RunningReplica]
    ) -> bool:
        """Return True if the known queue lengths differ by more than the threshold.

        Queue lengths come from the queue length cache when it is enabled;
        replicas with no cached value (or a cached value at or above their
        ``max_ongoing_requests``) are probed. Replicas whose length is still
        unknown after probing are ignored.
        """
        highest_queue_len = 0
        lowest_queue_len = float("inf")
        if self._use_replica_queue_len_cache:
            to_probe: List[RunningReplica] = []
            # Populate available queue lens from the cache.
            for replica in candidate_replicas:
                queue_len = self._replica_queue_len_cache.get(replica.replica_id)
                if queue_len is None or queue_len >= replica.max_ongoing_requests:
                    to_probe.append(replica)
                else:
                    highest_queue_len = max(highest_queue_len, queue_len)
                    lowest_queue_len = min(lowest_queue_len, queue_len)
        else:
            to_probe = candidate_replicas

        if to_probe:
            for _, queue_len in await self._probe_queue_lens(to_probe, 0):
                if queue_len is None:
                    continue
                highest_queue_len = max(highest_queue_len, queue_len)
                lowest_queue_len = min(lowest_queue_len, queue_len)

        # With no known queue length at all the difference is -inf, which is
        # treated as balanced.
        return highest_queue_len - lowest_queue_len > self._imbalanced_threshold

    async def _prefix_match_best_replicas(
        self,
        pending_request: Optional[PendingRequest],
        candidate_replicas: List[RunningReplica],
    ) -> List[List[RunningReplica]]:
        """Pick replicas by prefix match, or return no preference.

        0. No usable request text, or load is imbalanced: return ``[[]]`` so
           the caller falls back to the power-of-two-choices selection.
        1. Load is balanced: choose the replica(s) with the highest prefix
           match rate. If that rate is below ``match_rate_threshold`` (this
           includes an empty prompt), choose the smallest tenants of the
           prefix tree instead.

        Returns:
            A single-element list holding the list of chosen replicas; the
            inner list is empty when there is no prefix-based preference.
        """
        if pending_request is None or not pending_request.args:
            return [[]]

        input_text = self._extract_text_from_request(pending_request)
        if input_text is None:
            return [[]]

        if await self._is_load_imbalanced(candidate_replicas):
            return [[]]

        # Convert candidate replica IDs to strings for prefix matching.
        candidate_id_strings = [
            replica.replica_id.to_full_id_str() for replica in candidate_replicas
        ]
        # Await the object refs instead of calling ray.get(): a blocking get
        # inside a coroutine stalls the event loop for the whole round trip.
        matched_text, matched_tenant_id_strings = (
            await self._tree_actor.prefix_match.remote(
                input_text, candidate_id_strings
            )
        )

        # An empty prompt cannot match anything; avoid dividing by zero.
        match_rate = len(matched_text) / len(input_text) if input_text else 0.0
        if match_rate < self._match_rate_threshold:
            chosen_id_strings = await self._tree_actor.get_smallest_tenants.remote()
        else:
            chosen_id_strings = matched_tenant_id_strings

        chosen_replicas: List[RunningReplica] = []
        for id_string in chosen_id_strings or []:
            replica_id = ReplicaID.from_full_id_str(id_string)
            # The tree may name tenants this router does not (or no longer)
            # know about, e.g. replicas removed while awaiting the tree or
            # tenants added by another router sharing the tree. Skip them
            # rather than failing the routing attempt.
            if replica_id in self._replicas:
                chosen_replicas.append(self._replicas[replica_id])
        return [chosen_replicas]

    def on_replica_actor_died(self, replica_id: ReplicaID):
        """Drop replica from replica set so it's not considered for future requests.

        Also removes the replica's tenant from the prefix tree.
        """
        super().on_replica_actor_died(replica_id)
        ray.get(self._tree_actor.remove_tenants.remote([replica_id.to_full_id_str()]))

    def update_replicas(self, replicas: List[RunningReplica]):
        """Update the set of available replicas and mirror the change in the tree.

        Runs the parent ``update_replicas`` logic, then adds a tenant for every
        replica ID that appeared and removes the tenant of every replica ID
        that disappeared. If eviction is enabled and this router has not yet
        requested it, the tree actor's eviction loop is started.
        """
        # 1) Record the old replica IDs.
        old_ids = set(self._replica_id_set)

        # 2) Run the default update_replicas logic.
        super().update_replicas(replicas)

        # 3) Figure out which replicas were added / removed.
        new_ids = set(self._replica_id_set)
        added = new_ids - old_ids
        removed = old_ids - new_ids

        # 4) Update the prefix tree with the changes.
        if added:
            added_strings = [rid.to_full_id_str() for rid in added]
            ray.get(self._tree_actor.add_tenants.remote(added_strings, time.time()))

        if removed:
            removed_strings = [rid.to_full_id_str() for rid in removed]
            ray.get(self._tree_actor.remove_tenants.remote(removed_strings))

        # === Start tasks (if enabled and not already running) ===
        if self._do_eviction and not self._eviction_loop_running:
            ray.get(
                self._tree_actor.start_eviction_loop.remote(
                    self._eviction_threshold_chars,
                    self._eviction_target_chars,
                    self._eviction_interval_secs,
                )
            )
            self._eviction_loop_running = True

    async def choose_replicas(
        self,
        candidate_replicas: List[RunningReplica],
        pending_request: Optional[PendingRequest] = None,
    ) -> List[List[RunningReplica]]:
        """Choose replicas for one routing iteration.

        The power-of-two-choices result is computed first and used as the
        fallback. When a pending request is available, the candidates are
        narrowed by multiplexed model ID (if the request has one) or by
        locality, and a prefix-based choice among them is preferred whenever
        it yields at least one replica.

        Returns:
            The prefix-based choice (a single-element list holding the chosen
            replicas) or, otherwise, the fallback result unchanged.
            NOTE(review): the fallback is assumed to have the same
            list-of-lists shape; confirm against
            ``PowerOfTwoChoicesRequestRouter.choose_replicas``.
        """
        # Get fallback replicas from PowerOfTwoChoicesRequestRouter.
        fallback_replicas = await PowerOfTwoChoicesRequestRouter.choose_replicas(
            self,
            candidate_replicas=candidate_replicas,
            pending_request=pending_request,
        )
        if pending_request is None or not fallback_replicas:
            return fallback_replicas

        if pending_request.metadata.multiplexed_model_id:
            # Get candidates for multiplexed model ID.
            candidate_replica_ids = self.apply_multiplex_routing(
                pending_request=pending_request,
            )
        else:
            # Get candidates for locality preference.
            candidate_replica_ids = self.apply_locality_routing(
                pending_request=pending_request,
            )
        if not candidate_replica_ids:
            return fallback_replicas

        # Convert candidate replica IDs to RunningReplica objects.
        replica_id_to_replica_map = {
            replica.replica_id: replica for replica in candidate_replicas
        }
        candidate_replicas = [
            replica_id_to_replica_map[candidate_replica_id]
            for candidate_replica_id in candidate_replica_ids
        ]
        chosen_replicas = await self._prefix_match_best_replicas(
            pending_request, candidate_replicas
        )
        if chosen_replicas[0]:
            return chosen_replicas

        return fallback_replicas

    def on_request_routed(
        self,
        pending_request: PendingRequest,
        replica_id: ReplicaID,
        result: ReplicaResult,
    ):
        """Record the routed request's prompt in the prefix tree.

        Called as a callback when a request has been routed to a replica.
        Only the prompt is inserted, not the response (streaming responses
        make that complicated).
        """
        if pending_request is None or not pending_request.args:
            return

        input_text = self._extract_text_from_request(pending_request)
        if input_text is None:
            return

        # Insert into prefix tree under the replica that received the request.
        ray.get(
            self._tree_actor.insert.remote(
                input_text, replica_id.to_full_id_str(), time.time()
            )
        )