Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
690551c
Pass parameters to custom routers through LLMConfig
eicherseiji Jun 17, 2025
5783bff
Add request_router_kwargs to protobuf
eicherseiji Jun 18, 2025
410fe60
Remove unnecessary lock from eviction loop
eicherseiji Jun 18, 2025
b4c03f4
Add request_router_kwargs to deployment options
eicherseiji Jun 18, 2025
f61543d
Apply suggestions from code review
eicherseiji Jun 18, 2025
4dfeaba
Address code review
eicherseiji Jun 18, 2025
f62375e
Update api docs
eicherseiji Jun 18, 2025
bea9baa
Add comment to Protobuf
eicherseiji Jun 19, 2025
de9561a
Remove prefix tree changes from this PR
eicherseiji Jun 19, 2025
b363565
Create initialize_state() to avoid **kwargs in RequestRouter swallowi…
eicherseiji Jun 23, 2025
cd0a14b
Create RouterConfig
eicherseiji Jun 25, 2025
e08e372
Remove excess whitespace from serve.proto
eicherseiji Jun 25, 2025
3b01a12
Fix java files
eicherseiji Jun 26, 2025
f5444c8
Pickle/unpickle request_router_kwargs
eicherseiji Jun 27, 2025
3c403e2
Add to API .rst, don't serialize bytes, update tests
eicherseiji Jun 30, 2025
13db8dd
Update comments
eicherseiji Jun 30, 2025
b7b3c1d
Fix tests to use RouterConfig, document attributes
eicherseiji Jun 30, 2025
07462aa
Fix bad rebase
eicherseiji Jun 30, 2025
e408b13
Fix test to use RouterConfig
eicherseiji Jul 1, 2025
f73df36
Set router_kwargs to empty bytes in Java
eicherseiji Jul 2, 2025
b415a2f
Fix test to use RouterConfig
eicherseiji Jul 1, 2025
6f77624
Fix ThroughputAwareRequestRouterApp
eicherseiji Jul 2, 2025
8c03deb
Only support request_router_kwargs in Python
eicherseiji Jul 2, 2025
67a8453
Lint
eicherseiji Jul 2, 2025
2776012
Add test
eicherseiji Jul 3, 2025
7e1850e
Apply suggestions from code review
eicherseiji Jul 3, 2025
dcf843a
Improve RouterConfig documentation
eicherseiji Jul 3, 2025
42d08c3
Sphinx format
eicherseiji Jul 3, 2025
1c5b556
Rename RouterConfig -> RequestRouterConfig
eicherseiji Jul 3, 2025
f867e09
Add docstring to initialize_state and move to ray.llm
eicherseiji Jul 3, 2025
0d3938c
Rename serialized_request_router_cls -> _serialized_request_router_cl…
eicherseiji Jul 3, 2025
ffa0ef1
Rename RouterConfig.java -> RequestRouterConfig.java
eicherseiji Jul 3, 2025
e33402e
Update Protobuf field name
eicherseiji Jul 3, 2025
34aab98
Complete renaming
eicherseiji Jul 4, 2025
b4d42e0
Correct API stability
eicherseiji Jul 8, 2025
c3b0d44
Remove PrefixAwarePow2ReplicaRouter.__init__
eicherseiji Jul 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/serve/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ See the [model composition guide](serve-model-composition) for how to update cod
serve.config.gRPCOptions
serve.config.HTTPOptions
serve.config.AutoscalingConfig
serve.config.RequestRouterConfig
```

### Schemas
Expand Down
25 changes: 13 additions & 12 deletions doc/source/serve/doc_code/custom_request_router_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

# __begin_deploy_app_with_uniform_request_router__
from ray import serve
from ray.serve.context import _get_internal_replica_context
from ray.serve.request_router import ReplicaID
import time
from collections import defaultdict
from ray.serve.context import _get_internal_replica_context
from typing import Any, Dict
from ray.serve.config import RequestRouterConfig


@serve.deployment(
request_router_class="custom_request_router:UniformRequestRouter",
request_router_config=RequestRouterConfig(
request_router_class="custom_request_router:UniformRequestRouter",
),
num_replicas=10,
ray_actor_options={"num_cpus": 0},
)
Expand All @@ -30,22 +36,17 @@ async def __call__(self):


# __begin_deploy_app_with_throughput_aware_request_router__
import time
from collections import defaultdict
from ray import serve
from ray.serve.context import _get_internal_replica_context
from typing import Any, Dict


def _time_ms() -> int:
return int(time.time() * 1000)


@serve.deployment(
request_router_class="custom_request_router:ThroughputAwareRequestRouter",
request_router_config=RequestRouterConfig(
request_router_class="custom_request_router:ThroughputAwareRequestRouter",
request_routing_stats_period_s=1,
request_routing_stats_timeout_s=1,
),
num_replicas=3,
request_routing_stats_period_s=1,
request_routing_stats_timeout_s=1,
ray_actor_options={"num_cpus": 0},
)
class ThroughputAwareRequestRouterApp:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,10 @@ public class DeploymentConfig implements Serializable {
*/
private Double healthCheckTimeoutS = Constants.DEFAULT_HEALTH_CHECK_TIMEOUT_S;

/** Frequency at which the controller will record request routing stats. */
private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S;

/**
* Timeout that the controller will wait for a response from the replica's request routing stats
* before retrying.
*/
private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S;

private AutoscalingConfig autoscalingConfig;

private RequestRouterConfig routerConfig;

/** This flag is used to let replica know they are deplyed from a different language. */
private Boolean isCrossLanguage = false;

Expand Down Expand Up @@ -150,23 +143,23 @@ public DeploymentConfig setHealthCheckTimeoutS(Double healthCheckTimeoutS) {
}

public Double getRequestRoutingStatsPeriodS() {
return requestRoutingStatsPeriodS;
return routerConfig.getRequestRoutingStatsPeriodS();
}

public DeploymentConfig setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) {
if (requestRoutingStatsPeriodS != null) {
this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS;
routerConfig.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS);
}
return this;
}

public Double getRequestRoutingStatsTimeoutS() {
return requestRoutingStatsTimeoutS;
return routerConfig.getRequestRoutingStatsTimeoutS();
}

public DeploymentConfig setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) {
if (requestRoutingStatsTimeoutS != null) {
this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS;
routerConfig.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS);
}
return this;
}
Expand All @@ -180,6 +173,15 @@ public DeploymentConfig setAutoscalingConfig(AutoscalingConfig autoscalingConfig
return this;
}

public RequestRouterConfig getRequestRouterConfig() {
return routerConfig;
}

public DeploymentConfig setRequestRouterConfig(RequestRouterConfig routerConfig) {
this.routerConfig = routerConfig;
return this;
}

public boolean isCrossLanguage() {
return isCrossLanguage;
}
Expand Down Expand Up @@ -230,8 +232,6 @@ public byte[] toProtoBytes() {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS)
.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage)
.setVersion(version);
Expand All @@ -241,6 +241,9 @@ public byte[] toProtoBytes() {
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}
if (null != routerConfig) {
builder.setRequestRouterConfig(routerConfig.toProto());
}
return builder.build().toByteArray();
}

Expand All @@ -253,8 +256,6 @@ public io.ray.serve.generated.DeploymentConfig toProto() {
.setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS)
.setHealthCheckPeriodS(healthCheckPeriodS)
.setHealthCheckTimeoutS(healthCheckTimeoutS)
.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS)
.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS)
.setIsCrossLanguage(isCrossLanguage)
.setDeploymentLanguage(deploymentLanguage);
if (null != userConfig) {
Expand All @@ -263,6 +264,9 @@ public io.ray.serve.generated.DeploymentConfig toProto() {
if (null != autoscalingConfig) {
builder.setAutoscalingConfig(autoscalingConfig.toProto());
}
if (null != routerConfig) {
builder.setRequestRouterConfig(routerConfig.toProto());
}
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package io.ray.serve.config;

import io.ray.serve.common.Constants;
import java.io.Serializable;

public class RequestRouterConfig implements Serializable {
/** Frequency at which the controller will record request routing stats. */
private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S;

/**
* Timeout that the controller waits for a response from the replica's request routing stats
* before retrying.
*/
private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S;

public Double getRequestRoutingStatsPeriodS() {
return requestRoutingStatsPeriodS;
}

public Double getRequestRoutingStatsTimeoutS() {
return requestRoutingStatsTimeoutS;
}

public void setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) {
this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS;
}

public void setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) {
this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS;
}

public io.ray.serve.generated.RequestRouterConfig toProto() {
return io.ray.serve.generated.RequestRouterConfig.newBuilder()
.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS)
.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)

import ray
from ray.actor import ActorHandle
from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import (
PrefixTreeActor,
)
Expand Down Expand Up @@ -52,27 +53,36 @@ class PrefixAwarePow2ReplicaRouter(LocalityMixin, MultiplexMixin, RequestRouter)
increasing cache locality and reducing overhead for language model inference.
"""

def __init__(
def initialize_state(
self,
*args,
imbalanced_threshold=10,
match_rate_threshold=0.1,
do_eviction=False,
eviction_threshold_chars=400_000,
eviction_target_chars=360_000,
eviction_interval_secs=10,
tree_actor=None,
**kwargs,
imbalanced_threshold: Optional[int] = 10,
match_rate_threshold: Optional[float] = 0.1,
do_eviction: Optional[bool] = False,
eviction_threshold_chars: Optional[int] = 400_000,
eviction_target_chars: Optional[int] = 360_000,
eviction_interval_secs: Optional[int] = 10,
tree_actor: Optional[ActorHandle] = None,
):
super().__init__(*args, **kwargs)
if tree_actor is None:
# Use a detached actor to avoid issues with actor lifetime since this is shared between routers
self._tree_actor = PrefixTreeActor.options(
name="LlmPrefixTreeActor", get_if_exists=True, lifetime="detached"
).remote()
else:
self._tree_actor = tree_actor
"""Initialize the prefix-aware routing state and configuration.

Args:
imbalanced_threshold: Threshold for queue length difference to consider
load balanced. When the difference between replica queue lengths is
less than this value, prefix-aware routing is used.
match_rate_threshold: Minimum prefix match rate (0.0-1.0) required to
use prefix-aware routing. If match rate is below this threshold,
falls back to smallest tenant selection.
do_eviction: Whether to enable automatic eviction of old prefix tree
entries to manage memory usage.
eviction_threshold_chars: Maximum number of characters in the prefix
tree before eviction is triggered.
eviction_target_chars: Target number of characters to reduce the
prefix tree to during eviction.
eviction_interval_secs: Interval in seconds between eviction checks
when eviction is enabled.
tree_actor: The actor to use for the prefix tree in a test environment.
If None, a detached actor will be created/retrieved.
"""
# === Prefix-aware routing logic hyperparameters ===
self._imbalanced_threshold = imbalanced_threshold
self._match_rate_threshold = match_rate_threshold
Expand All @@ -89,6 +99,14 @@ def __init__(
)
self._eviction_interval_secs = eviction_interval_secs

if tree_actor is None:
# Use a detached actor to avoid issues with actor lifetime since this is shared between routers
self._tree_actor = PrefixTreeActor.options(
name="LlmPrefixTreeActor", get_if_exists=True, lifetime="detached"
).remote()
else:
self._tree_actor = tree_actor

def _extract_text_from_request(self, pending_request: PendingRequest) -> str:
"""Extracts the text content from a pending request for prefix matching.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import ray
from ray._common.utils import get_or_create_event_loop
from ray.llm._internal.serve.request_router.prefix_aware.prefix_aware_router import (
PrefixAwarePow2ReplicaRouter,
)
from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import (
PrefixTreeActor,
)
Expand All @@ -14,9 +17,6 @@
RequestMetadata,
)
from ray.serve._private.request_router.common import PendingRequest
from ray.serve._private.request_router.prefix_aware_router import (
PrefixAwarePow2ReplicaRouter,
)
from ray.serve._private.test_utils import MockTimer
from ray.serve._private.utils import generate_request_id
from ray.serve.tests.unit.test_pow_2_request_router import (
Expand Down Expand Up @@ -48,20 +48,22 @@ async def construct_request_router(loop: asyncio.AbstractEventLoop):
deployment_id=DeploymentID(name="TEST_DEPLOYMENT"),
handle_source=DeploymentHandleSource.REPLICA,
use_replica_queue_len_cache=False,
imbalanced_threshold=params.get("imbalanced_threshold", 10),
match_rate_threshold=params.get("match_rate_threshold", 0.1),
do_eviction=params.get("do_eviction", False),
eviction_threshold_chars=params.get("eviction_threshold_chars"),
eviction_target_chars=params.get("eviction_target_chars"),
eviction_interval_secs=params.get("eviction_interval_secs"),
get_curr_time_s=TIMER.time,
tree_actor=tree_actor,
)
return request_router

request_router = asyncio.new_event_loop().run_until_complete(
construct_request_router(get_or_create_event_loop())
)
request_router.initialize_state(
imbalanced_threshold=params.get("imbalanced_threshold", 10),
match_rate_threshold=params.get("match_rate_threshold", 0.1),
do_eviction=params.get("do_eviction", False),
eviction_threshold_chars=params.get("eviction_threshold_chars"),
eviction_target_chars=params.get("eviction_target_chars"),
eviction_interval_secs=params.get("eviction_interval_secs"),
tree_actor=tree_actor,
)

yield request_router
assert request_router.curr_num_routing_tasks == 0
Expand Down Expand Up @@ -124,7 +126,7 @@ async def test_fallback_when_no_prompt(self, prefix_request_router):

req = fake_pending_request()
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(req)
chosen = await prefix_request_router._choose_replica_for_request(req)
assert chosen == r1

@pytest.mark.asyncio
Expand Down Expand Up @@ -161,7 +163,7 @@ async def test_fallback_when_imbalanced(self, prefix_request_router):

req = fake_pending_request(prompt="hello world")
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(req)
chosen = await prefix_request_router._choose_replica_for_request(req)
# Even though r2 has a higher match rate, it is not chosen because the load is imbalanced
assert chosen == r1

Expand Down Expand Up @@ -199,13 +201,13 @@ async def test_high_match_rate_selects_matching_replica(

prompt_req = fake_pending_request(prompt="Hello world")
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(prompt_req)
chosen = await prefix_request_router._choose_replica_for_request(prompt_req)
assert chosen == r2
chat_req = fake_pending_request(
messages=[{"content": "Hello"}, {"content": " world"}]
)
for _ in range(10):
chosen = await prefix_request_router.choose_replica_for_request(chat_req)
chosen = await prefix_request_router._choose_replica_for_request(chat_req)
assert chosen == r2

@pytest.mark.asyncio
Expand Down Expand Up @@ -240,14 +242,15 @@ async def test_low_match_rate_uses_smallest_tree(self, prefix_request_router):
for _ in range(10):
# Both tenants have 0% match rate, so the smaller tenant (r1) is chosen
assert (
await prefix_request_router.choose_replica_for_request(prompt_req) == r1
await prefix_request_router._choose_replica_for_request(prompt_req)
== r1
)

chat_req = fake_pending_request(messages=[{"content": "z"}])
for _ in range(10):
# Both tenants have 0% match rate, so the smaller tenant (r1) is chosen
assert (
await prefix_request_router.choose_replica_for_request(chat_req) == r1
await prefix_request_router._choose_replica_for_request(chat_req) == r1
)


Expand Down
Loading