diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6909cf84b..0246105d2 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -15,7 +15,7 @@ COPY .git/ .git/
 # Copy the rest of the application code
 COPY src/ src/
 
-ARG INSTALL_OPTIONAL_DEP=semantic_cache
+ARG INSTALL_OPTIONAL_DEP=semantic_cache,lmcache
 ENV INSTALL_OPTIONAL_DEP=${INSTALL_OPTIONAL_DEP}
 
 # Install dependencies (use cache, and delete after install, to speed up the build)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index ee4c5e0ca..63e90deb2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -94,6 +94,7 @@ def add_line(self, line: str, source: str, *lineno: int) -> None:
     "kubernetes",
     "prometheus_client",
     "uhashring",
+    "lmcache",
 ]
 
 intersphinx_mapping = {
diff --git a/examples/kvaware_routing/README.md b/examples/kvaware_routing/README.md
new file mode 100644
index 000000000..01fe7175d
--- /dev/null
+++ b/examples/kvaware_routing/README.md
@@ -0,0 +1,98 @@
+# KV Cache Aware Routing Example
+
+This example demonstrates how to set up and run KV cache aware routing with multiple vLLM servers locally without k8s. k8s native support for KV cache aware routing is coming soon.
+
+## Prerequisites
+
+- CUDA-capable GPUs (at least 2 GPUs)
+- vLLM installed
+
+## Setup
+
+1. Install the routers and LMCache locally:
+
+```bash
+uv pip install -e .
+git clone https://github.com/LMCache/LMCache.git
+uv pip install -e LMCache
+```
+
+## Running the Example
+
+### 1. Start first vLLM Server
+
+Run the following command to start the first vLLM server on GPU 0:
+
+```bash
+LMCACHE_LOG_LEVEL=DEBUG \
+LMCACHE_USE_EXPERIMENTAL=True \
+LMCACHE_CONFIG_FILE=examples/kvaware_routing/lmcache1.yaml \
+CUDA_VISIBLE_DEVICES=0 \
+vllm serve mistralai/Mistral-7B-Instruct-v0.2 \
+    --no-enable-prefix-caching \
+    --port 8000 \
+    --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}'
+```
+
+### 2. 
Start second vLLM Server + +Run the following command to start the second vLLM server on GPU 1: + +```bash +LMCACHE_LOG_LEVEL=DEBUG \ +LMCACHE_USE_EXPERIMENTAL=True \ +LMCACHE_CONFIG_FILE=examples/kvaware_routing/lmcache2.yaml \ +CUDA_VISIBLE_DEVICES=1 \ +vllm serve mistralai/Mistral-7B-Instruct-v0.2 \ + --no-enable-prefix-caching \ + --port 8001 \ + --kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}' +``` + +### 3. Start the Router + +Start the router on port 8005: + +```bash +bash run-router.sh 8005 +``` + +### 4. Send Test Requests + +Send a request to the router: + +```bash +bash send_request.sh 8005 +``` + +### 5. Verify Cache Behavior + +Watch the logs on the vLLM server. You should see logs similar to: + +```log +[2025-05-01 22:09:02,807] LMCache DEBUG: Sending 1 messages (worker.py:157:lmcache.experimental.cache_controller.worker) +[2025-05-01 22:09:02,811] LMCache DEBUG: Stored 351 out of total 351 tokens (cache_engine.py:256:lmcache.experimental.cache_engine) +[2025-05-01 22:09:02,811] LMCache DEBUG: Sending 1 messages (worker.py:157:lmcache.experimental.cache_controller.worker) +``` + +### 6. 
Test Cache Retrieval + +Send the same request again: + +```bash +bash send_request.sh 8005 +``` + +You should now see cache retrieval logs: + +```log +[2025-05-01 22:09:20,704] LMCache INFO: Reqid: cmpl-a76ffbd76f3140ae889f721d137b8412-0, Total tokens 351, LMCache hit tokens: 350, need to load: 350 (vllm_v1_adapter.py:561:lmcache.integration.vllm.vllm_v1_adapter) +[2025-05-01 22:09:20,705] LMCache DEBUG: Scheduled to load 350 tokens for request cmpl-a76ffbd76f3140ae889f721d137b8412-0 (vllm_v1_adapter.py:273:lmcache.integration.vllm.vllm_v1_adapter) +[2025-05-01 22:09:20,716] LMCache DEBUG: Retrieved 351 out of 351 out of total 351 tokens (cache_engine.py:329:lmcache.experimental.cache_engine) +``` + +## Expected Behavior + +- The first request will store the KV cache +- The second request will retrieve the KV cache, demonstrating the cache-aware routing functionality +- The logs will show the number of tokens injected and retrieved from the cache diff --git a/examples/kvaware_routing/lmcache1.yaml b/examples/kvaware_routing/lmcache1.yaml new file mode 100644 index 000000000..3a575ba3f --- /dev/null +++ b/examples/kvaware_routing/lmcache1.yaml @@ -0,0 +1,9 @@ +chunk_size: 256 +local_cpu: True +max_local_cpu_size: 5 + +# cache controller configurations +enable_controller: True +lmcache_instance_id: "http://localhost:8000" +controller_url: "localhost:9001" +lmcache_worker_url: "localhost:8002" diff --git a/examples/kvaware_routing/lmcache2.yaml b/examples/kvaware_routing/lmcache2.yaml new file mode 100644 index 000000000..59c421c82 --- /dev/null +++ b/examples/kvaware_routing/lmcache2.yaml @@ -0,0 +1,9 @@ +chunk_size: 256 +local_cpu: True +max_local_cpu_size: 5 + +# cache controller configurations +enable_controller: True +lmcache_instance_id: "http://localhost:8001" +controller_url: "localhost:9001" +lmcache_worker_url: "localhost:8003" diff --git a/pyproject.toml b/pyproject.toml index 0f75149e5..68bee9f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 
+36,9 @@ semantic_cache = [
     "faiss-cpu==1.10.0",
     "huggingface-hub==0.25.2", # downgrade to 0.25.2 to avoid breaking changes
 ]
+lmcache = [
+    "lmcache==0.2.11"
+]
 
 [build-system]
 requires = ["setuptools>=68", "setuptools_scm[toml]>=8.0"]
@@ -52,5 +55,5 @@ lint = [
 ]
 test = [
     "pytest>=8.3.4",
-    "pytest-asyncio>=0.25.3",
+    "pytest-asyncio>=0.25.3"
 ]
diff --git a/requirements-test.txt b/requirements-test.txt
index 7c1e7006b..71f7facff 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -1,5 +1,6 @@
 faiss-cpu>=1.7.4
 huggingface-hub==0.25.2
+lmcache==0.2.11
 pytest
 pytest-asyncio
 sentence-transformers>=2.2.2
diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py
index addaba9ea..2ee945f8a 100644
--- a/src/vllm_router/app.py
+++ b/src/vllm_router/app.py
@@ -135,7 +135,11 @@ def initialize_all(app: FastAPI, args):
         args.batch_processor, args.file_storage_path, app.state.batch_storage
     )
 
-    initialize_routing_logic(args.routing_logic, session_key=args.session_key)
+    initialize_routing_logic(
+        args.routing_logic,
+        session_key=args.session_key,
+        lmcache_controller_port=args.lmcache_controller_port,
+    )
 
     # Initialize feature gates
     initialize_feature_gates(args.feature_gates)
diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py
index 0e0569e63..f0206b423 100644
--- a/src/vllm_router/parsers/parser.py
+++ b/src/vllm_router/parsers/parser.py
@@ -99,9 +99,15 @@ def parse_args():
         "--routing-logic",
         type=str,
         required=True,
-        choices=["roundrobin", "session"],
+        choices=["roundrobin", "session", "kvaware"],
         help="The routing logic to use",
     )
+    parser.add_argument(
+        "--lmcache-controller-port",
+        type=int,
+        default=9000,
+        help="The port of the LMCache controller.",
+    )
     parser.add_argument(
         "--session-key",
         type=str,
diff --git a/src/vllm_router/requirements.txt b/src/vllm_router/requirements.txt
index 7151639f3..e7e347617 100644
--- a/src/vllm_router/requirements.txt
+++ b/src/vllm_router/requirements.txt
@@ -2,6 
aiofiles==24.1.0
 fastapi==0.115.8
 httpx==0.28.1
 kubernetes==32.0.0
+lmcache==0.2.11
 numpy==1.26.4
 prometheus_client==0.21.1
 python-multipart==0.0.20
diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py
index 4dd0c39b3..21370fad7 100644
--- a/src/vllm_router/routers/routing_logic.py
+++ b/src/vllm_router/routers/routing_logic.py
@@ -13,10 +13,15 @@
 # limitations under the License.
 
 import abc
+import asyncio
 import enum
+import threading
 from typing import Dict, List
 
+import requests
 from fastapi import Request
+from lmcache.experimental.cache_controller import controller_manager
+from lmcache.experimental.cache_controller.message import LookupMsg
 from uhashring import HashRing
 
 from vllm_router.log import init_logger
@@ -31,6 +36,7 @@ class RoutingLogic(str, enum.Enum):
     ROUND_ROBIN = "roundrobin"
     SESSION_BASED = "session"
+    KVAWARE = "kvaware"
 
 
 class RoutingInterface(metaclass=SingletonABCMeta):
@@ -186,6 +192,58 @@ def route_request(
         return url
 
 
+
+class KvawareRouter(RoutingInterface):
+    def __init__(self, lmcache_controller_port: int):
+        self.lmcache_controller_port = lmcache_controller_port
+        self.kv_manager = controller_manager.LMCacheControllerManager(
+            f"0.0.0.0:{self.lmcache_controller_port}"
+        )
+        self.req_id = 0
+
+    def start_kv_manager(self):
+        self.loop = asyncio.new_event_loop()
+        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
+        self.thread.start()
+        asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop)
+
+    def kv_aware_routing(self, msg: LookupMsg) -> str:
+        # Submit the lookup onto the long-lived event loop started by
+        # start_kv_manager() instead of creating (and prematurely closing)
+        # a fresh loop that never runs the coroutine.
+        future = asyncio.run_coroutine_threadsafe(
+            self.kv_manager.handle_orchestration_message(msg), self.loop
+        )
+        # Block until the controller replies; return the result, not the Future.
+        return future.result()
+
+    async def route_request(
+        self,
+        endpoints: List[EndpointInfo],
+        engine_stats: Dict[str, EngineStats],
+        
request_stats: Dict[str, RequestStats], + request: Request, + request_json: Dict, + ) -> str: + url = endpoints[0].url + "/tokenize" + headers = {"Content-Type": "application/json"} + data = {"model": endpoints[0].model_name, "prompt": request_json["prompt"]} + response = requests.post(url, headers=headers, json=data).json() + token_ids = response["tokens"] + msg = LookupMsg(tokens=token_ids) + instance_id = self.kv_aware_routing(msg) + if instance_id is None: + len_engines = len(endpoints) + chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines] + self.req_id += 1 + return chosen.url + else: + self.req_id += 1 + logger.info( + f"Routing request to {instance_id.best_instance_id} found by kvaware router" + ) + return instance_id.best_instance_id + + # Instead of managing a global _global_router, we can define the initialization functions as: def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs @@ -196,6 +254,11 @@ def initialize_routing_logic( elif routing_logic == RoutingLogic.SESSION_BASED: logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}") return SessionRouter(kwargs.get("session_key")) + elif routing_logic == RoutingLogic.KVAWARE: + logger.info("Initializing kvaware routing logic") + router = KvawareRouter(kwargs.get("lmcache_controller_port")) + router.start_kv_manager() + return router else: raise ValueError(f"Invalid routing logic {routing_logic}") @@ -204,7 +267,7 @@ def reconfigure_routing_logic( routing_logic: RoutingLogic, *args, **kwargs ) -> RoutingInterface: # Remove the existing routers from the singleton registry - for cls in (SessionRouter, RoundRobinRouter): + for cls in (SessionRouter, RoundRobinRouter, KvawareRouter): if cls in SingletonABCMeta._instances: del SingletonABCMeta._instances[cls] return initialize_routing_logic(routing_logic, *args, **kwargs) @@ -212,7 +275,7 @@ def reconfigure_routing_logic( def get_routing_logic() -> RoutingInterface: # Look up in our singleton 
registry which router (if any) has been created. - for cls in (SessionRouter, RoundRobinRouter): + for cls in (SessionRouter, RoundRobinRouter, KvawareRouter): if cls in SingletonABCMeta._instances: return cls() raise ValueError("The global router has not been initialized") diff --git a/src/vllm_router/run-router.sh b/src/vllm_router/run-router.sh index 58643900b..561fe84d0 100755 --- a/src/vllm_router/run-router.sh +++ b/src/vllm_router/run-router.sh @@ -14,18 +14,18 @@ fi # --engine-stats-interval 10 \ # --log-stats -# Use this command when testing with static service discovery +# Use this command when testing with static service discovery for KV cache aware routing python3 -m vllm_router.app --port "$1" \ --service-discovery static \ - --static-backends "http://localhost:9000" \ - --static-models "fake_model_name" \ + --static-backends "http://localhost:8000,http://localhost:8001" \ + --static-models "mistralai/Mistral-7B-Instruct-v0.2,mistralai/Mistral-7B-Instruct-v0.2" \ --log-stats \ --log-stats-interval 10 \ --engine-stats-interval 10 \ --request-stats-window 10 \ --request-stats-window 10 \ - --routing-logic session \ - --session-key "x-user-id" + --routing-logic kvaware \ + --lmcache-controller-port 9001 # Use this command when testing with roundrobin routing logic #python3 router.py --port "$1" \ diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 3a6557404..c0316f876 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -22,6 +22,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from vllm_router.log import init_logger +from vllm_router.routers.routing_logic import KvawareRouter from vllm_router.service_discovery import get_service_discovery from vllm_router.services.request_service.rewriter import ( get_request_rewriter, @@ -207,9 +208,14 @@ async def route_general_request( ) 
logger.debug(f"Routing request {request_id} for model: {requested_model}")
 
-    server_url = request.app.state.router.route_request(
-        endpoints, engine_stats, request_stats, request
-    )
+    if isinstance(request.app.state.router, KvawareRouter):
+        server_url = await request.app.state.router.route_request(
+            endpoints, engine_stats, request_stats, request, request_json
+        )
+    else:
+        server_url = request.app.state.router.route_request(
+            endpoints, engine_stats, request_stats, request
+        )
 
     curr_time = time.time()
     logger.info(
         f"Routing request {request_id} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"