2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ COPY .git/ .git/
# Copy the rest of the application code
COPY src/ src/

ARG INSTALL_OPTIONAL_DEP=semantic_cache
ARG INSTALL_OPTIONAL_DEP=semantic_cache,lmcache
ENV INSTALL_OPTIONAL_DEP=${INSTALL_OPTIONAL_DEP}

# Install dependencies (use cache, and delete after install, to speed up the build)
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -94,6 +94,7 @@ def add_line(self, line: str, source: str, *lineno: int) -> None:
"kubernetes",
"prometheus_client",
"uhashring",
"lmcache",
]

intersphinx_mapping = {
98 changes: 98 additions & 0 deletions examples/kvaware_routing/README.md
@@ -0,0 +1,98 @@
# KV Cache Aware Routing Example

This example demonstrates how to set up and run KV cache aware routing with multiple vLLM servers locally, without Kubernetes. Native Kubernetes support for KV cache aware routing is coming soon.

## Prerequisites

- CUDA-capable GPUs (at least 2 GPUs)
- vLLM installed

## Setup

1. Install the router and LMCache locally:

```bash
uv pip install -e <path to production stack>
git clone https://github.com/LMCache/LMCache.git
uv pip install -e LMCache
```

## Running the Example

### 1. Start the First vLLM Server

Run the following command to start the first vLLM server on GPU 0:

```bash
LMCACHE_LOG_LEVEL=DEBUG \
LMCACHE_USE_EXPERIMENTAL=True \
LMCACHE_CONFIG_FILE=examples/kvaware_routing/lmcache1.yaml \
CUDA_VISIBLE_DEVICES=0 \
vllm serve mistralai/Mistral-7B-Instruct-v0.2 \
--no-enable-prefix-caching \
--port 8000 \
--kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}'
```

### 2. Start the Second vLLM Server

Run the following command to start the second vLLM server on GPU 1:

```bash
LMCACHE_LOG_LEVEL=DEBUG \
LMCACHE_USE_EXPERIMENTAL=True \
LMCACHE_CONFIG_FILE=examples/kvaware_routing/lmcache2.yaml \
CUDA_VISIBLE_DEVICES=1 \
vllm serve mistralai/Mistral-7B-Instruct-v0.2 \
--no-enable-prefix-caching \
--port 8001 \
--kv-transfer-config '{"kv_connector":"LMCacheConnectorV1","kv_role":"kv_both"}'
```

### 3. Start the Router

Start the router on port 8005:

```bash
bash run-router.sh 8005
```

### 4. Send Test Requests

Send a request to the router:

```bash
bash send_request.sh 8005
```

### 5. Verify Cache Behavior

Watch the logs on the vLLM server. You should see logs similar to:

```log
[2025-05-01 22:09:02,807] LMCache DEBUG: Sending 1 messages (worker.py:157:lmcache.experimental.cache_controller.worker)
[2025-05-01 22:09:02,811] LMCache DEBUG: Stored 351 out of total 351 tokens (cache_engine.py:256:lmcache.experimental.cache_engine)
[2025-05-01 22:09:02,811] LMCache DEBUG: Sending 1 messages (worker.py:157:lmcache.experimental.cache_controller.worker)
```

### 6. Test Cache Retrieval

Send the same request again:

```bash
bash send_request.sh 8005
```

You should now see cache retrieval logs:

```log
[2025-05-01 22:09:20,704] LMCache INFO: Reqid: cmpl-a76ffbd76f3140ae889f721d137b8412-0, Total tokens 351, LMCache hit tokens: 350, need to load: 350 (vllm_v1_adapter.py:561:lmcache.integration.vllm.vllm_v1_adapter)
[2025-05-01 22:09:20,705] LMCache DEBUG: Scheduled to load 350 tokens for request cmpl-a76ffbd76f3140ae889f721d137b8412-0 (vllm_v1_adapter.py:273:lmcache.integration.vllm.vllm_v1_adapter)
[2025-05-01 22:09:20,716] LMCache DEBUG: Retrieved 351 out of 351 out of total 351 tokens (cache_engine.py:329:lmcache.experimental.cache_engine)
```

## Expected Behavior

- The first request will store the KV cache
- The second request will retrieve the KV cache, demonstrating the cache-aware routing functionality
- The logs will show the number of tokens stored in and retrieved from the cache
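The hit counts in the logs above can be checked mechanically. Below is an illustrative parser for the hit-count line; the log format is taken from the sample output in this example and may change between LMCache versions:

```python
import re

# Matches the hit-count entry emitted by vllm_v1_adapter.py, e.g.
# "Total tokens 351, LMCache hit tokens: 350, need to load: 350"
HIT_RE = re.compile(r"Total tokens (\d+), LMCache hit tokens: (\d+)")

def hit_ratio(log_line: str) -> float:
    """Return the fraction of prompt tokens served from the KV cache."""
    m = HIT_RE.search(log_line)
    if m is None:
        raise ValueError("no LMCache hit-count entry in line")
    total, hits = int(m.group(1)), int(m.group(2))
    return hits / total

line = ("[2025-05-01 22:09:20,704] LMCache INFO: Reqid: cmpl-...-0, "
        "Total tokens 351, LMCache hit tokens: 350, need to load: 350")
print(f"{hit_ratio(line):.1%}")  # prints 99.7%
```

On the second request the ratio should be close to 100%, since only the final token of the prompt misses the cache.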
9 changes: 9 additions & 0 deletions examples/kvaware_routing/lmcache1.yaml
@@ -0,0 +1,9 @@
chunk_size: 256
local_cpu: True
max_local_cpu_size: 5

# cache controller configurations
enable_controller: True
# the instance id doubles as the URL the router returns for matched requests
lmcache_instance_id: "http://localhost:8000"
controller_url: "localhost:9001"
lmcache_worker_url: "localhost:8002"
9 changes: 9 additions & 0 deletions examples/kvaware_routing/lmcache2.yaml
@@ -0,0 +1,9 @@
chunk_size: 256
local_cpu: True
max_local_cpu_size: 5

# cache controller configurations
enable_controller: True
# the instance id doubles as the URL the router returns for matched requests
lmcache_instance_id: "http://localhost:8001"
controller_url: "localhost:9001"
lmcache_worker_url: "localhost:8003"
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -36,6 +36,9 @@ semantic_cache = [
"faiss-cpu==1.10.0",
"huggingface-hub==0.25.2", # downgrade to 0.25.2 to avoid breaking changes
]
lmcache_cache = [
"lmcache==0.2.11"
]

[build-system]
requires = ["setuptools>=68", "setuptools_scm[toml]>=8.0"]
@@ -52,5 +55,5 @@ lint = [
]
test = [
"pytest>=8.3.4",
"pytest-asyncio>=0.25.3",
"pytest-asyncio>=0.25.3"
]
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -1,5 +1,6 @@
faiss-cpu>=1.7.4
huggingface-hub==0.25.2
lmcache==0.2.11
pytest
pytest-asyncio
sentence-transformers>=2.2.2
6 changes: 5 additions & 1 deletion src/vllm_router/app.py
@@ -135,7 +135,11 @@ def initialize_all(app: FastAPI, args):
args.batch_processor, args.file_storage_path, app.state.batch_storage
)

initialize_routing_logic(args.routing_logic, session_key=args.session_key)
initialize_routing_logic(
args.routing_logic,
session_key=args.session_key,
lmcache_controller_port=args.lmcache_controller_port,
)

# Initialize feature gates
initialize_feature_gates(args.feature_gates)
8 changes: 7 additions & 1 deletion src/vllm_router/parsers/parser.py
@@ -99,9 +99,15 @@ def parse_args():
"--routing-logic",
type=str,
required=True,
choices=["roundrobin", "session"],
choices=["roundrobin", "session", "kvaware"],
help="The routing logic to use",
)
parser.add_argument(
"--lmcache-controller-port",
type=int,
default=9000,
help="The port of the LMCache controller.",
)
parser.add_argument(
"--session-key",
type=str,
1 change: 1 addition & 0 deletions src/vllm_router/requirements.txt
@@ -2,6 +2,7 @@ aiofiles==24.1.0
fastapi==0.115.8
httpx==0.28.1
kubernetes==32.0.0
lmcache==0.2.11
numpy==1.26.4
prometheus_client==0.21.1
python-multipart==0.0.20
67 changes: 65 additions & 2 deletions src/vllm_router/routers/routing_logic.py
@@ -13,10 +13,15 @@
# limitations under the License.

import abc
import asyncio
import enum
import threading
from typing import Dict, List

import requests
from fastapi import Request
from lmcache.experimental.cache_controller import controller_manager
from lmcache.experimental.cache_controller.message import LookupMsg
from uhashring import HashRing

from vllm_router.log import init_logger
@@ -31,6 +36,7 @@
class RoutingLogic(str, enum.Enum):
ROUND_ROBIN = "roundrobin"
SESSION_BASED = "session"
KVAWARE = "kvaware"


class RoutingInterface(metaclass=SingletonABCMeta):
@@ -186,6 +192,58 @@ def route_request(
return url


class KvawareRouter(RoutingInterface):
def __init__(self, lmcache_controller_port: int):
self.lmcache_controller_port = lmcache_controller_port
self.kv_manager = controller_manager.LMCacheControllerManager(
f"0.0.0.0:{self.lmcache_controller_port}"
)
self.req_id = 0

def start_kv_manager(self):
self.loop = asyncio.new_event_loop()
self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
self.thread.start()
asyncio.run_coroutine_threadsafe(self.kv_manager.start_all(), self.loop)

def kv_aware_routing(self, msg: LookupMsg) -> str:
# Reuse the event loop started by start_kv_manager and block until
# the cache controller answers the lookup. run_coroutine_threadsafe
# returns a concurrent.futures.Future, so the result must be awaited
# via .result(); spinning up and closing a fresh loop per call would
# discard the pending coroutine.
future = asyncio.run_coroutine_threadsafe(
self.kv_manager.handle_orchestration_message(msg), self.loop
)
return future.result()

async def route_request(
self,
endpoints: List[EndpointInfo],
engine_stats: Dict[str, EngineStats],
request_stats: Dict[str, RequestStats],
request: Request,
request_json: Dict,
) -> str:
url = endpoints[0].url + "/tokenize"
headers = {"Content-Type": "application/json"}
data = {"model": endpoints[0].model_name, "prompt": request_json["prompt"]}
response = requests.post(url, headers=headers, json=data).json()
token_ids = response["tokens"]
msg = LookupMsg(tokens=token_ids)
instance_id = self.kv_aware_routing(msg)
if instance_id is None:
len_engines = len(endpoints)
chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines]
self.req_id += 1
return chosen.url
else:
self.req_id += 1
logger.info(
f"Routing request to {instance_id.best_instance_id} found by kvaware router"
)
return instance_id.best_instance_id
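When the lookup returns no instance, `route_request` above falls back to deterministic round-robin over the endpoints sorted by URL. A standalone sketch of that fallback, with endpoints reduced to bare URL strings for illustration:

```python
from typing import List, Optional

def pick_endpoint(urls: List[str], req_id: int,
                  cached_instance: Optional[str] = None) -> str:
    """Prefer the cache-aware choice; otherwise rotate round-robin
    over the URLs, sorted so the rotation order is stable."""
    if cached_instance is not None:
        return cached_instance
    return sorted(urls)[req_id % len(urls)]

urls = ["http://localhost:8001", "http://localhost:8000"]
print(pick_endpoint(urls, 0))  # http://localhost:8000
print(pick_endpoint(urls, 1))  # http://localhost:8001
```

Sorting before indexing matters: endpoint discovery may return the list in any order, and an unsorted rotation would not distribute requests evenly across restarts.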


# Instead of managing a global _global_router, we can define the initialization functions as:
def initialize_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
@@ -196,6 +254,11 @@ def initialize_routing_logic(
elif routing_logic == RoutingLogic.SESSION_BASED:
logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
return SessionRouter(kwargs.get("session_key"))
elif routing_logic == RoutingLogic.KVAWARE:
logger.info("Initializing kvaware routing logic")
router = KvawareRouter(kwargs.get("lmcache_controller_port"))
router.start_kv_manager()
return router
else:
raise ValueError(f"Invalid routing logic {routing_logic}")

@@ -204,15 +267,15 @@ def reconfigure_routing_logic(
routing_logic: RoutingLogic, *args, **kwargs
) -> RoutingInterface:
# Remove the existing routers from the singleton registry
for cls in (SessionRouter, RoundRobinRouter):
for cls in (SessionRouter, RoundRobinRouter, KvawareRouter):
if cls in SingletonABCMeta._instances:
del SingletonABCMeta._instances[cls]
return initialize_routing_logic(routing_logic, *args, **kwargs)


def get_routing_logic() -> RoutingInterface:
# Look up in our singleton registry which router (if any) has been created.
for cls in (SessionRouter, RoundRobinRouter):
for cls in (SessionRouter, RoundRobinRouter, KvawareRouter):
if cls in SingletonABCMeta._instances:
return cls()
raise ValueError("The global router has not been initialized")
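`get_routing_logic` relies on `SingletonABCMeta` keeping one instance per router class in its `_instances` registry. The actual metaclass is defined elsewhere in `vllm_router` and may differ; a minimal sketch of how such a registry typically works:

```python
import abc

class SingletonABCMeta(abc.ABCMeta):
    """Metaclass that returns the same instance for every construction
    of a given class, keyed by class in a shared registry."""
    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        if cls not in SingletonABCMeta._instances:
            SingletonABCMeta._instances[cls] = super().__call__(*args, **kwargs)
        return SingletonABCMeta._instances[cls]

class Router(metaclass=SingletonABCMeta):
    pass

print(Router() is Router())  # True
```

This is why `reconfigure_routing_logic` must delete the old entry from `_instances` before re-initializing: otherwise calling the class again would silently return the stale router.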
10 changes: 5 additions & 5 deletions src/vllm_router/run-router.sh
@@ -14,18 +14,18 @@ fi
# --engine-stats-interval 10 \
# --log-stats

# Use this command when testing with static service discovery
# Use this command when testing with static service discovery for KV cache aware routing
python3 -m vllm_router.app --port "$1" \
--service-discovery static \
--static-backends "http://localhost:9000" \
--static-models "fake_model_name" \
--static-backends "http://localhost:8000,http://localhost:8001" \
--static-models "mistralai/Mistral-7B-Instruct-v0.2,mistralai/Mistral-7B-Instruct-v0.2" \
--log-stats \
--log-stats-interval 10 \
--engine-stats-interval 10 \
--request-stats-window 10 \
--request-stats-window 10 \
--routing-logic session \
--session-key "x-user-id"
--routing-logic kvaware \
--lmcache-controller-port 9001

# Use this command when testing with roundrobin routing logic
#python3 router.py --port "$1" \
12 changes: 9 additions & 3 deletions src/vllm_router/services/request_service/request.py
@@ -22,6 +22,7 @@
from fastapi.responses import JSONResponse, StreamingResponse

from vllm_router.log import init_logger
from vllm_router.routers.routing_logic import KvawareRouter
from vllm_router.service_discovery import get_service_discovery
from vllm_router.services.request_service.rewriter import (
get_request_rewriter,
@@ -207,9 +208,14 @@ async def route_general_request(
)

logger.debug(f"Routing request {request_id} for model: {requested_model}")
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
if isinstance(request.app.state.router, KvawareRouter):
# KvawareRouter.route_request is a coroutine and must be awaited
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
)
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
curr_time = time.time()
logger.info(
f"Routing request {request_id} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"