diff --git a/doc/source/serve/api/index.md b/doc/source/serve/api/index.md index c63c1d95cf3b..52a67da52c9b 100644 --- a/doc/source/serve/api/index.md +++ b/doc/source/serve/api/index.md @@ -83,6 +83,7 @@ See the [model composition guide](serve-model-composition) for how to update cod serve.config.gRPCOptions serve.config.HTTPOptions serve.config.AutoscalingConfig + serve.config.RequestRouterConfig ``` ### Schemas diff --git a/doc/source/serve/doc_code/custom_request_router_app.py b/doc/source/serve/doc_code/custom_request_router_app.py index e88280f5d1bb..afabaa2d5711 100644 --- a/doc/source/serve/doc_code/custom_request_router_app.py +++ b/doc/source/serve/doc_code/custom_request_router_app.py @@ -2,12 +2,18 @@ # __begin_deploy_app_with_uniform_request_router__ from ray import serve -from ray.serve.context import _get_internal_replica_context from ray.serve.request_router import ReplicaID +import time +from collections import defaultdict +from ray.serve.context import _get_internal_replica_context +from typing import Any, Dict +from ray.serve.config import RequestRouterConfig @serve.deployment( - request_router_class="custom_request_router:UniformRequestRouter", + request_router_config=RequestRouterConfig( + request_router_class="custom_request_router:UniformRequestRouter", + ), num_replicas=10, ray_actor_options={"num_cpus": 0}, ) @@ -30,22 +36,17 @@ async def __call__(self): # __begin_deploy_app_with_throughput_aware_request_router__ -import time -from collections import defaultdict -from ray import serve -from ray.serve.context import _get_internal_replica_context -from typing import Any, Dict - - def _time_ms() -> int: return int(time.time() * 1000) @serve.deployment( - request_router_class="custom_request_router:ThroughputAwareRequestRouter", + request_router_config=RequestRouterConfig( + request_router_class="custom_request_router:ThroughputAwareRequestRouter", + request_routing_stats_period_s=1, + request_routing_stats_timeout_s=1, + ), num_replicas=3, - 
request_routing_stats_period_s=1, - request_routing_stats_timeout_s=1, ray_actor_options={"num_cpus": 0}, ) class ThroughputAwareRequestRouterApp: diff --git a/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java b/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java index 2a86cb2d3afc..50e5fa5297f3 100644 --- a/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java +++ b/java/serve/src/main/java/io/ray/serve/config/DeploymentConfig.java @@ -52,17 +52,10 @@ public class DeploymentConfig implements Serializable { */ private Double healthCheckTimeoutS = Constants.DEFAULT_HEALTH_CHECK_TIMEOUT_S; - /** Frequency at which the controller will record request routing stats. */ - private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S; - - /** - * Timeout that the controller will wait for a response from the replica's request routing stats - * before retrying. - */ - private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S; - private AutoscalingConfig autoscalingConfig; + private RequestRouterConfig routerConfig; + /** This flag is used to let replica know they are deplyed from a different language. 
*/ private Boolean isCrossLanguage = false; @@ -150,23 +143,23 @@ public DeploymentConfig setHealthCheckTimeoutS(Double healthCheckTimeoutS) { } public Double getRequestRoutingStatsPeriodS() { - return requestRoutingStatsPeriodS; + return routerConfig.getRequestRoutingStatsPeriodS(); } public DeploymentConfig setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) { if (requestRoutingStatsPeriodS != null) { - this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS; + routerConfig.setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS); } return this; } public Double getRequestRoutingStatsTimeoutS() { - return requestRoutingStatsTimeoutS; + return routerConfig.getRequestRoutingStatsTimeoutS(); } public DeploymentConfig setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) { if (requestRoutingStatsTimeoutS != null) { - this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS; + routerConfig.setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS); } return this; } @@ -180,6 +173,15 @@ public DeploymentConfig setAutoscalingConfig(AutoscalingConfig autoscalingConfig return this; } + public RequestRouterConfig getRequestRouterConfig() { + return routerConfig; + } + + public DeploymentConfig setRequestRouterConfig(RequestRouterConfig routerConfig) { + this.routerConfig = routerConfig; + return this; + } + public boolean isCrossLanguage() { return isCrossLanguage; } @@ -230,8 +232,6 @@ public byte[] toProtoBytes() { .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckTimeoutS(healthCheckTimeoutS) - .setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS) - .setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS) .setIsCrossLanguage(isCrossLanguage) .setDeploymentLanguage(deploymentLanguage) .setVersion(version); @@ -241,6 +241,9 @@ public byte[] toProtoBytes() { if (null != autoscalingConfig) { builder.setAutoscalingConfig(autoscalingConfig.toProto()); } + if (null != 
routerConfig) { + builder.setRequestRouterConfig(routerConfig.toProto()); + } return builder.build().toByteArray(); } @@ -253,8 +256,6 @@ public io.ray.serve.generated.DeploymentConfig toProto() { .setGracefulShutdownTimeoutS(gracefulShutdownTimeoutS) .setHealthCheckPeriodS(healthCheckPeriodS) .setHealthCheckTimeoutS(healthCheckTimeoutS) - .setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS) - .setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS) .setIsCrossLanguage(isCrossLanguage) .setDeploymentLanguage(deploymentLanguage); if (null != userConfig) { @@ -263,6 +264,9 @@ public io.ray.serve.generated.DeploymentConfig toProto() { if (null != autoscalingConfig) { builder.setAutoscalingConfig(autoscalingConfig.toProto()); } + if (null != routerConfig) { + builder.setRequestRouterConfig(routerConfig.toProto()); + } return builder.build(); } diff --git a/java/serve/src/main/java/io/ray/serve/config/RequestRouterConfig.java b/java/serve/src/main/java/io/ray/serve/config/RequestRouterConfig.java new file mode 100644 index 000000000000..10a61d7543b4 --- /dev/null +++ b/java/serve/src/main/java/io/ray/serve/config/RequestRouterConfig.java @@ -0,0 +1,38 @@ +package io.ray.serve.config; + +import io.ray.serve.common.Constants; +import java.io.Serializable; + +public class RequestRouterConfig implements Serializable { + /** Frequency at which the controller will record request routing stats. */ + private Double requestRoutingStatsPeriodS = Constants.DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S; + + /** + * Timeout that the controller waits for a response from the replica's request routing stats + * before retrying. 
+ */ + private Double requestRoutingStatsTimeoutS = Constants.DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S; + + public Double getRequestRoutingStatsPeriodS() { + return requestRoutingStatsPeriodS; + } + + public Double getRequestRoutingStatsTimeoutS() { + return requestRoutingStatsTimeoutS; + } + + public void setRequestRoutingStatsPeriodS(Double requestRoutingStatsPeriodS) { + this.requestRoutingStatsPeriodS = requestRoutingStatsPeriodS; + } + + public void setRequestRoutingStatsTimeoutS(Double requestRoutingStatsTimeoutS) { + this.requestRoutingStatsTimeoutS = requestRoutingStatsTimeoutS; + } + + public io.ray.serve.generated.RequestRouterConfig toProto() { + return io.ray.serve.generated.RequestRouterConfig.newBuilder() + .setRequestRoutingStatsPeriodS(requestRoutingStatsPeriodS) + .setRequestRoutingStatsTimeoutS(requestRoutingStatsTimeoutS) + .build(); + } +} diff --git a/python/ray/serve/_private/request_router/prefix_aware_router.py b/python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_aware_router.py similarity index 89% rename from python/ray/serve/_private/request_router/prefix_aware_router.py rename to python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_aware_router.py index cbc7b362bb34..5c6e8b28d504 100644 --- a/python/ray/serve/_private/request_router/prefix_aware_router.py +++ b/python/ray/llm/_internal/serve/request_router/prefix_aware/prefix_aware_router.py @@ -7,6 +7,7 @@ ) import ray +from ray.actor import ActorHandle from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import ( PrefixTreeActor, ) @@ -52,27 +53,36 @@ class PrefixAwarePow2ReplicaRouter(LocalityMixin, MultiplexMixin, RequestRouter) increasing cache locality and reducing overhead for language model inference. 
""" - def __init__( + def initialize_state( self, - *args, - imbalanced_threshold=10, - match_rate_threshold=0.1, - do_eviction=False, - eviction_threshold_chars=400_000, - eviction_target_chars=360_000, - eviction_interval_secs=10, - tree_actor=None, - **kwargs, + imbalanced_threshold: Optional[int] = 10, + match_rate_threshold: Optional[float] = 0.1, + do_eviction: Optional[bool] = False, + eviction_threshold_chars: Optional[int] = 400_000, + eviction_target_chars: Optional[int] = 360_000, + eviction_interval_secs: Optional[int] = 10, + tree_actor: Optional[ActorHandle] = None, ): - super().__init__(*args, **kwargs) - if tree_actor is None: - # Use a detached actor to avoid issues with actor lifetime since this is shared between routers - self._tree_actor = PrefixTreeActor.options( - name="LlmPrefixTreeActor", get_if_exists=True, lifetime="detached" - ).remote() - else: - self._tree_actor = tree_actor + """Initialize the prefix-aware routing state and configuration. + Args: + imbalanced_threshold: Threshold for queue length difference to consider + load balanced. When the difference between replica queue lengths is + less than this value, prefix-aware routing is used. + match_rate_threshold: Minimum prefix match rate (0.0-1.0) required to + use prefix-aware routing. If match rate is below this threshold, + falls back to smallest tenant selection. + do_eviction: Whether to enable automatic eviction of old prefix tree + entries to manage memory usage. + eviction_threshold_chars: Maximum number of characters in the prefix + tree before eviction is triggered. + eviction_target_chars: Target number of characters to reduce the + prefix tree to during eviction. + eviction_interval_secs: Interval in seconds between eviction checks + when eviction is enabled. + tree_actor: The actor to use for the prefix tree in a test environment. + If None, a detached actor will be created/retrieved. 
+ """ # === Prefix-aware routing logic hyperparameters === self._imbalanced_threshold = imbalanced_threshold self._match_rate_threshold = match_rate_threshold @@ -89,6 +99,14 @@ def __init__( ) self._eviction_interval_secs = eviction_interval_secs + if tree_actor is None: + # Use a detached actor to avoid issues with actor lifetime since this is shared between routers + self._tree_actor = PrefixTreeActor.options( + name="LlmPrefixTreeActor", get_if_exists=True, lifetime="detached" + ).remote() + else: + self._tree_actor = tree_actor + def _extract_text_from_request(self, pending_request: PendingRequest) -> str: """Extracts the text content from a pending request for prefix matching. diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py index 4ecd45f0dceb..e65efe879cc9 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_aware_request_router.py @@ -5,6 +5,9 @@ import ray from ray._common.utils import get_or_create_event_loop +from ray.llm._internal.serve.request_router.prefix_aware.prefix_aware_router import ( + PrefixAwarePow2ReplicaRouter, +) from ray.llm._internal.serve.request_router.prefix_aware.prefix_tree import ( PrefixTreeActor, ) @@ -14,9 +17,6 @@ RequestMetadata, ) from ray.serve._private.request_router.common import PendingRequest -from ray.serve._private.request_router.prefix_aware_router import ( - PrefixAwarePow2ReplicaRouter, -) from ray.serve._private.test_utils import MockTimer from ray.serve._private.utils import generate_request_id from ray.serve.tests.unit.test_pow_2_request_router import ( @@ -48,20 +48,22 @@ async def construct_request_router(loop: asyncio.AbstractEventLoop): deployment_id=DeploymentID(name="TEST_DEPLOYMENT"), handle_source=DeploymentHandleSource.REPLICA, use_replica_queue_len_cache=False, - 
imbalanced_threshold=params.get("imbalanced_threshold", 10), - match_rate_threshold=params.get("match_rate_threshold", 0.1), - do_eviction=params.get("do_eviction", False), - eviction_threshold_chars=params.get("eviction_threshold_chars"), - eviction_target_chars=params.get("eviction_target_chars"), - eviction_interval_secs=params.get("eviction_interval_secs"), get_curr_time_s=TIMER.time, - tree_actor=tree_actor, ) return request_router request_router = asyncio.new_event_loop().run_until_complete( construct_request_router(get_or_create_event_loop()) ) + request_router.initialize_state( + imbalanced_threshold=params.get("imbalanced_threshold", 10), + match_rate_threshold=params.get("match_rate_threshold", 0.1), + do_eviction=params.get("do_eviction", False), + eviction_threshold_chars=params.get("eviction_threshold_chars"), + eviction_target_chars=params.get("eviction_target_chars"), + eviction_interval_secs=params.get("eviction_interval_secs"), + tree_actor=tree_actor, + ) yield request_router assert request_router.curr_num_routing_tasks == 0 @@ -124,7 +126,7 @@ async def test_fallback_when_no_prompt(self, prefix_request_router): req = fake_pending_request() for _ in range(10): - chosen = await prefix_request_router.choose_replica_for_request(req) + chosen = await prefix_request_router._choose_replica_for_request(req) assert chosen == r1 @pytest.mark.asyncio @@ -161,7 +163,7 @@ async def test_fallback_when_imbalanced(self, prefix_request_router): req = fake_pending_request(prompt="hello world") for _ in range(10): - chosen = await prefix_request_router.choose_replica_for_request(req) + chosen = await prefix_request_router._choose_replica_for_request(req) # Even though r2 has a higher match rate, it is not chosen because the load is imbalanced assert chosen == r1 @@ -199,13 +201,13 @@ async def test_high_match_rate_selects_matching_replica( prompt_req = fake_pending_request(prompt="Hello world") for _ in range(10): - chosen = await 
prefix_request_router.choose_replica_for_request(prompt_req) + chosen = await prefix_request_router._choose_replica_for_request(prompt_req) assert chosen == r2 chat_req = fake_pending_request( messages=[{"content": "Hello"}, {"content": " world"}] ) for _ in range(10): - chosen = await prefix_request_router.choose_replica_for_request(chat_req) + chosen = await prefix_request_router._choose_replica_for_request(chat_req) assert chosen == r2 @pytest.mark.asyncio @@ -240,14 +242,15 @@ async def test_low_match_rate_uses_smallest_tree(self, prefix_request_router): for _ in range(10): # Both tenants have 0% match rate, so the smaller tenant (r1) is chosen assert ( - await prefix_request_router.choose_replica_for_request(prompt_req) == r1 + await prefix_request_router._choose_replica_for_request(prompt_req) + == r1 ) chat_req = fake_pending_request(messages=[{"content": "z"}]) for _ in range(10): # Both tenants have 0% match rate, so the smaller tenant (r1) is chosen assert ( - await prefix_request_router.choose_replica_for_request(chat_req) == r1 + await prefix_request_router._choose_replica_for_request(chat_req) == r1 ) diff --git a/python/ray/serve/_private/config.py b/python/ray/serve/_private/config.py index 6a5ed3655b7f..3293f5557ee4 100644 --- a/python/ray/serve/_private/config.py +++ b/python/ray/serve/_private/config.py @@ -13,10 +13,9 @@ NonNegativeInt, PositiveFloat, PositiveInt, - root_validator, validator, ) -from ray._common.utils import import_attr, resources_from_ray_options +from ray._common.utils import resources_from_ray_options from ray._private import ray_option_utils from ray._private.serialization import pickle_dumps from ray.serve._private.constants import ( @@ -25,13 +24,10 @@ DEFAULT_HEALTH_CHECK_PERIOD_S, DEFAULT_HEALTH_CHECK_TIMEOUT_S, DEFAULT_MAX_ONGOING_REQUESTS, - DEFAULT_REQUEST_ROUTER_PATH, - DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S, - DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S, MAX_REPLICAS_PER_NODE_MAX_VALUE, ) from ray.serve._private.utils 
import DEFAULT, DeploymentOptionUpdateType -from ray.serve.config import AutoscalingConfig +from ray.serve.config import AutoscalingConfig, RequestRouterConfig from ray.serve.generated.serve_pb2 import ( AutoscalingConfig as AutoscalingConfigProto, DeploymentConfig as DeploymentConfigProto, @@ -39,6 +35,7 @@ EncodingType as EncodingTypeProto, LoggingConfig as LoggingConfigProto, ReplicaConfig as ReplicaConfigProto, + RequestRouterConfig as RequestRouterConfigProto, ) from ray.util.placement_group import validate_placement_group @@ -121,14 +118,11 @@ class DeploymentConfig(BaseModel): health_check_timeout_s: Timeout that the controller waits for a response from the replica's health check before marking it unhealthy. - request_routing_stats_period_s: Frequency at which the controller - record request routing stats. - request_routing_stats_timeout_s: Timeout that the controller waits - for a response from the replica's record routing stats call. autoscaling_config: Autoscaling configuration. logging_config: Configuration for deployment logs. user_configured_option_names: The names of options manually configured by the user. + request_router_config: Configuration for deployment request router. 
""" num_replicas: Optional[NonNegativeInt] = Field( @@ -163,19 +157,16 @@ class DeploymentConfig(BaseModel): default=DEFAULT_HEALTH_CHECK_TIMEOUT_S, update_type=DeploymentOptionUpdateType.NeedsReconfigure, ) - request_routing_stats_period_s: PositiveFloat = Field( - default=DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S, - update_type=DeploymentOptionUpdateType.NeedsReconfigure, - ) - request_routing_stats_timeout_s: PositiveFloat = Field( - default=DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S, - update_type=DeploymentOptionUpdateType.NeedsReconfigure, - ) autoscaling_config: Optional[AutoscalingConfig] = Field( default=None, update_type=DeploymentOptionUpdateType.NeedsActorReconfigure ) + request_router_config: RequestRouterConfig = Field( + default=RequestRouterConfig(), + update_type=DeploymentOptionUpdateType.NeedsActorReconfigure, + ) + # This flag is used to let replica know they are deployed from # a different language. is_cross_language: bool = False @@ -197,14 +188,6 @@ class DeploymentConfig(BaseModel): # Contains the names of deployment options manually set by the user user_configured_option_names: Set[str] = set() - # Cloudpickled request router class. - serialized_request_router_cls: bytes = Field(default=b"") - - # Custom request router config. Defaults to the power of two request router. - request_router_class: Union[str, Callable] = Field( - default=DEFAULT_REQUEST_ROUTER_PATH - ) - class Config: validate_assignment = True arbitrary_types_allowed = True @@ -248,33 +231,6 @@ def validate_max_queued_requests(cls, v): return v - @root_validator - def import_and_serialize_request_router_cls(cls, values) -> Dict[str, Any]: - """Import and serialize request router class with cloudpickle. - - Import the request router if it's passed in as a string import path. - Then cloudpickle the request router and set to - `serialized_request_router_cls`. 
- """ - request_router_class = values.get("request_router_class") - if isinstance(request_router_class, Callable): - request_router_class = ( - f"{request_router_class.__module__}.{request_router_class.__name__}" - ) - - request_router_path = request_router_class or DEFAULT_REQUEST_ROUTER_PATH - request_router_class = import_attr(request_router_path) - - values["serialized_request_router_cls"] = cloudpickle.dumps( - request_router_class - ) - values["request_router_class"] = request_router_path - return values - - def get_request_router_class(self) -> Callable: - """Deserialize request router from cloudpickled bytes.""" - return cloudpickle.loads(self.serialized_request_router_cls) - def needs_pickle(self): return _needs_pickle(self.deployment_language, self.is_cross_language) @@ -287,12 +243,29 @@ def to_proto(self): data["autoscaling_config"] = AutoscalingConfigProto( **data["autoscaling_config"] ) + if data.get("request_router_config"): + router_kwargs = data["request_router_config"].get("request_router_kwargs") + if router_kwargs is not None: + if not router_kwargs: + data["request_router_config"]["request_router_kwargs"] = b"" + elif self.needs_pickle(): + # Protobuf requires bytes, so we need to pickle + data["request_router_config"][ + "request_router_kwargs" + ] = cloudpickle.dumps(router_kwargs) + else: + raise ValueError( + "Non-empty request_router_kwargs not supported" + f"for cross-language deployments. 
Got: {router_kwargs}" + ) + data["request_router_config"] = RequestRouterConfigProto( + **data["request_router_config"] + ) if data.get("logging_config"): if "encoding" in data["logging_config"]: data["logging_config"]["encoding"] = EncodingTypeProto.Value( data["logging_config"]["encoding"] ) - data["logging_config"] = LoggingConfigProto(**data["logging_config"]) data["user_configured_option_names"] = list( data["user_configured_option_names"] @@ -305,23 +278,45 @@ def to_proto_bytes(self): @classmethod def from_proto(cls, proto: DeploymentConfigProto): data = _proto_to_dict(proto) + deployment_language = ( + data["deployment_language"] + if "deployment_language" in data + else DeploymentLanguage.PYTHON + ) + is_cross_language = ( + data["is_cross_language"] if "is_cross_language" in data else False + ) + needs_pickle = _needs_pickle(deployment_language, is_cross_language) if "user_config" in data: if data["user_config"] != b"": - deployment_language = ( - data["deployment_language"] - if "deployment_language" in data - else DeploymentLanguage.PYTHON - ) - is_cross_language = ( - data["is_cross_language"] if "is_cross_language" in data else False - ) - needs_pickle = _needs_pickle(deployment_language, is_cross_language) if needs_pickle: data["user_config"] = cloudpickle.loads(proto.user_config) else: data["user_config"] = proto.user_config else: data["user_config"] = None + if "request_router_config" in data: + if "request_router_kwargs" in data["request_router_config"]: + request_router_kwargs = data["request_router_config"][ + "request_router_kwargs" + ] + if request_router_kwargs != b"": + if needs_pickle: + data["request_router_config"][ + "request_router_kwargs" + ] = cloudpickle.loads( + proto.request_router_config.request_router_kwargs + ) + else: + data["request_router_config"][ + "request_router_kwargs" + ] = proto.request_router_config.request_router_kwargs + else: + data["request_router_config"]["request_router_kwargs"] = {} + + 
data["request_router_config"] = RequestRouterConfig( + **data["request_router_config"] + ) if "autoscaling_config" in data: if not data["autoscaling_config"].get("upscale_smoothing_factor"): data["autoscaling_config"]["upscale_smoothing_factor"] = None diff --git a/python/ray/serve/_private/deployment_state.py b/python/ray/serve/_private/deployment_state.py index 1d5cdbe07630..d30c5ccee6fa 100644 --- a/python/ray/serve/_private/deployment_state.py +++ b/python/ray/serve/_private/deployment_state.py @@ -359,11 +359,15 @@ def health_check_timeout_s(self) -> float: @property def request_routing_stats_period_s(self) -> float: - return self.deployment_config.request_routing_stats_period_s + return ( + self.deployment_config.request_router_config.request_routing_stats_period_s + ) @property def request_routing_stats_timeout_s(self) -> float: - return self.deployment_config.request_routing_stats_timeout_s + return ( + self.deployment_config.request_router_config.request_routing_stats_timeout_s + ) @property def pid(self) -> Optional[int]: diff --git a/python/ray/serve/_private/request_router/request_router.py b/python/ray/serve/_private/request_router/request_router.py index 4b10a37537d0..ebfc2bc6a181 100644 --- a/python/ray/serve/_private/request_router/request_router.py +++ b/python/ray/serve/_private/request_router/request_router.py @@ -534,6 +534,13 @@ def __init__( ) self.num_routing_tasks_in_backoff_gauge.set(self.num_routing_tasks_in_backoff) + def initialize_state(self, **kwargs): + """ + Initialize the state of the request router. Called by the Ray Serve framework with the + contents of `RequestRouter.request_router_kwargs`. 
+ """ + pass + @property def _event_loop(self) -> asyncio.AbstractEventLoop: if self._lazily_fetched_loop is None: diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 32d10c8549eb..032fa9744958 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -377,6 +377,7 @@ def __init__( prefer_local_node_routing: bool, resolve_request_arg_func: Coroutine = resolve_deployment_response, request_router_class: Optional[Callable] = None, + request_router_kwargs: Optional[Dict[str, Any]] = None, request_router: Optional[RequestRouter] = None, _request_router_initialized_event: Optional[asyncio.Event] = None, ): @@ -391,6 +392,9 @@ def __init__( self._handle_source = handle_source self._event_loop = event_loop self._request_router_class = request_router_class + self._request_router_kwargs = ( + request_router_kwargs if request_router_kwargs else {} + ) self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests self._node_id = node_id self._availability_zone = availability_zone @@ -503,6 +507,7 @@ def request_router(self) -> Optional[RequestRouter]: prefer_local_az_routing=RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING, self_availability_zone=self._availability_zone, ) + request_router.initialize_state(**(self._request_router_kwargs)) # Populate the running replicas if they are already available. 
if self._running_replicas is not None: @@ -537,7 +542,12 @@ def update_deployment_targets(self, deployment_target_info: DeploymentTargetInfo self._running_replicas_populated = True def update_deployment_config(self, deployment_config: DeploymentConfig): - self._request_router_class = deployment_config.get_request_router_class() + self._request_router_class = ( + deployment_config.request_router_config.get_request_router_class() + ) + self._request_router_kwargs = ( + deployment_config.request_router_config.request_router_kwargs + ) self._metrics_manager.update_deployment_config( deployment_config, curr_num_replicas=len(self.request_router.curr_replicas), diff --git a/python/ray/serve/_private/version.py b/python/ray/serve/_private/version.py index 1068d5b71cb2..9242dfc928e9 100644 --- a/python/ray/serve/_private/version.py +++ b/python/ray/serve/_private/version.py @@ -186,6 +186,13 @@ def _get_serialized_options( elif isinstance(reconfigure_dict[option_name], BaseModel): reconfigure_dict[option_name] = reconfigure_dict[option_name].dict() + # Can't serialize bytes. The request router class is already + # included in the serialized config as request_router_class. 
+ if "request_router_config" in reconfigure_dict: + reconfigure_dict["request_router_config"].pop( + "_serialized_request_router_cls", None + ) + if ( isinstance(self.deployment_config.user_config, bytes) and "user_config" in reconfigure_dict diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index aea80fe88ce6..865d5cedbba1 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -28,7 +28,6 @@ ) from ray.serve._private.local_testing_mode import make_local_deployment_handle from ray.serve._private.logging_utils import configure_component_logger -from ray.serve._private.request_router.request_router import RequestRouter from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import ( DEFAULT, @@ -43,6 +42,7 @@ DeploymentMode, HTTPOptions, ProxyLocation, + RequestRouterConfig, gRPCOptions, ) from ray.serve.context import ( @@ -55,11 +55,7 @@ from ray.serve.exceptions import RayServeException from ray.serve.handle import DeploymentHandle from ray.serve.multiplex import _ModelMultiplexWrapper -from ray.serve.schema import ( - LoggingConfig, - ServeInstanceDetails, - ServeStatus, -) +from ray.serve.schema import LoggingConfig, ServeInstanceDetails, ServeStatus from ray.util.annotations import DeveloperAPI, PublicAPI from ray.serve._private import api as _private_api # isort:skip @@ -338,9 +334,9 @@ def deployment( health_check_period_s: Default[float] = DEFAULT.VALUE, health_check_timeout_s: Default[float] = DEFAULT.VALUE, logging_config: Default[Union[Dict, LoggingConfig, None]] = DEFAULT.VALUE, - request_router_class: Default[Union[str, RequestRouter, None]] = DEFAULT.VALUE, - request_routing_stats_period_s: Default[float] = DEFAULT.VALUE, - request_routing_stats_timeout_s: Default[float] = DEFAULT.VALUE, + request_router_config: Default[ + Union[Dict, RequestRouterConfig, None] + ] = DEFAULT.VALUE, ) -> Callable[[Callable], Deployment]: """Decorator that converts a Python class to a `Deployment`. 
@@ -405,20 +401,7 @@ class MyDeployment: check method to return before considering it as failed. Defaults to 30s. logging_config: Logging config options for the deployment. If provided, the config will be used to set up the Serve logger on the deployment. - request_router_class: The class of the request router used for this - deployment. This can be a string or a class. All the deployment - handle created for this deployment will use the routing policy - defined by the request router. Default to Serve's - PowerOfTwoChoicesRequestRouter. - request_routing_stats_period_s: Duration between record scheduling stats - calls for the replica. Defaults to 10s. The health check is by default a - no-op Actor call to the replica, but you can define your own request - scheduling stats using the "record_scheduling_stats" method in your - deployment. - request_routing_stats_timeout_s: Duration in seconds, that replicas wait for - a request scheduling stats method to return before considering it as failed. - Defaults to 30s. - + request_router_config: Config for the request router used for this deployment. 
Returns: `Deployment` """ @@ -483,14 +466,10 @@ class MyDeployment: health_check_period_s=health_check_period_s, health_check_timeout_s=health_check_timeout_s, logging_config=logging_config, - request_routing_stats_period_s=request_routing_stats_period_s, - request_routing_stats_timeout_s=request_routing_stats_timeout_s, + request_router_config=request_router_config, ) deployment_config.user_configured_option_names = set(user_configured_option_names) - if request_router_class is not DEFAULT.VALUE: - deployment_config.request_router_class = request_router_class - def decorator(_func_or_class): replica_config = ReplicaConfig.create( _func_or_class, diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index c19b537de72a..1386a99b48f4 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -1,7 +1,8 @@ +import json import logging import warnings from enum import Enum -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from ray import cloudpickle from ray._common.pydantic_compat import ( @@ -20,6 +21,9 @@ DEFAULT_GRPC_PORT, DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT, + DEFAULT_REQUEST_ROUTER_PATH, + DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S, + DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S, DEFAULT_TARGET_ONGOING_REQUESTS, DEFAULT_UVICORN_KEEP_ALIVE_TIMEOUT_S, SERVE_LOGGER_NAME, @@ -29,6 +33,132 @@ logger = logging.getLogger(SERVE_LOGGER_NAME) +@PublicAPI(stability="alpha") +class RequestRouterConfig(BaseModel): + """Config for the Serve request router. + + This class configures how Ray Serve routes requests to deployment replicas. The router is + responsible for selecting which replica should handle each incoming request based on the + configured routing policy. You can customize the routing behavior by specifying a custom + request router class and providing configuration parameters. 
+ + The router also manages periodic health checks and scheduling statistics collection from + replicas to make informed routing decisions. + + Example: + .. code-block:: python + + from ray.serve.config import RequestRouterConfig, DeploymentConfig + from ray import serve + + # Use default router with custom stats collection interval + request_router_config = RequestRouterConfig( + request_routing_stats_period_s=5.0, + request_routing_stats_timeout_s=15.0 + ) + + # Use custom router class + request_router_config = RequestRouterConfig( + request_router_class="ray.llm._internal.serve.request_router.prefix_aware.prefix_aware_router.PrefixAwarePow2ReplicaRouter", + request_router_kwargs={"imbalanced_threshold": 20} + ) + deployment_config = DeploymentConfig( + request_router_config=request_router_config + ) + deployment = serve.deploy( + "my_deployment", + deployment_config=deployment_config + ) + """ + + _serialized_request_router_cls: bytes = PrivateAttr(default=b"") + + request_router_class: Union[str, Callable] = Field( + default=DEFAULT_REQUEST_ROUTER_PATH, + description=( + "The class of the request router that Ray Serve uses for this deployment. This value can be " + "a string or a class. All the deployment handles that you create for this " + "deployment use the routing policy defined by the request router. " + "Default to Serve's PowerOfTwoChoicesRequestRouter." + ), + ) + request_router_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description=( + "Keyword arguments that Ray Serve passes to the request router class " + "initialize_state method." + ), + ) + + request_routing_stats_period_s: PositiveFloat = Field( + default=DEFAULT_REQUEST_ROUTING_STATS_PERIOD_S, + description=( + "Duration between record scheduling stats calls for the replica. " + "Defaults to 10s. 
The health check is by default a no-op Actor call " + "to the replica, but you can define your own request scheduling stats " + "using the 'record_scheduling_stats' method in your deployment." + ), + ) + + request_routing_stats_timeout_s: PositiveFloat = Field( + default=DEFAULT_REQUEST_ROUTING_STATS_TIMEOUT_S, + description=( + "Duration in seconds, that replicas wait for a request scheduling " + "stats method to return before considering it as failed. Defaults to 30s." + ), + ) + + @validator("request_router_kwargs", always=True) + def request_router_kwargs_json_serializable(cls, v): + if isinstance(v, bytes): + return v + if v is not None: + try: + json.dumps(v) + except TypeError as e: + raise ValueError( + f"request_router_kwargs is not JSON-serializable: {str(e)}." + ) + + return v + + def __init__(self, **kwargs: dict[str, Any]): + """Initialize RequestRouterConfig with the given parameters. + + Needed to serialize the request router class since validators are not called + for attributes that begin with an underscore. + + Args: + **kwargs: Keyword arguments to pass to BaseModel. + """ + super().__init__(**kwargs) + self._serialize_request_router_cls() + + def _serialize_request_router_cls(self) -> None: + """Import and serialize request router class with cloudpickle. + + Import the request router if you pass it in as a string import path. + Then cloudpickle the request router and set to + `_serialized_request_router_cls`. 
+ """ + request_router_class = self.request_router_class + if isinstance(request_router_class, Callable): + request_router_class = ( + f"{request_router_class.__module__}.{request_router_class.__name__}" + ) + + request_router_path = request_router_class or DEFAULT_REQUEST_ROUTER_PATH + request_router_class = import_attr(request_router_path) + + self._serialized_request_router_cls = cloudpickle.dumps(request_router_class) + # Update the request_router_class field to be the string path + self.request_router_class = request_router_path + + def get_request_router_class(self) -> Callable: + """Deserialize the request router from cloudpickled bytes.""" + return cloudpickle.loads(self._serialized_request_router_cls) + + @PublicAPI(stability="stable") class AutoscalingConfig(BaseModel): """Config for the Serve Autoscaler.""" @@ -43,13 +173,17 @@ class AutoscalingConfig(BaseModel): target_ongoing_requests: PositiveFloat = DEFAULT_TARGET_ONGOING_REQUESTS - # How often to scrape for metrics - metrics_interval_s: PositiveFloat = 10.0 - # Time window to average over for metrics. - look_back_period_s: PositiveFloat = 30.0 + metrics_interval_s: PositiveFloat = Field( + default=10.0, description="How often to scrape for metrics." + ) + look_back_period_s: PositiveFloat = Field( + default=30.0, description="Time window to average over for metrics." + ) - # DEPRECATED - smoothing_factor: PositiveFloat = 1.0 + smoothing_factor: PositiveFloat = Field( + default=1.0, + description="[DEPRECATED] Smoothing factor for autoscaling decisions.", + ) # DEPRECATED: replaced by `downscaling_factor` upscale_smoothing_factor: Optional[PositiveFloat] = Field( default=None, description="[DEPRECATED] Please use `upscaling_factor` instead." 
@@ -60,16 +194,23 @@ class AutoscalingConfig(BaseModel): description="[DEPRECATED] Please use `downscaling_factor` instead.", ) - # Multiplicative "gain" factor to limit scaling decisions - upscaling_factor: Optional[PositiveFloat] = None - downscaling_factor: Optional[PositiveFloat] = None + upscaling_factor: Optional[PositiveFloat] = Field( + default=None, + description='Multiplicative "gain" factor to limit upscaling decisions.', + ) + downscaling_factor: Optional[PositiveFloat] = Field( + default=None, + description='Multiplicative "gain" factor to limit downscaling decisions.', + ) # How frequently to make autoscaling decisions # loop_period_s: float = CONTROL_LOOP_PERIOD_S - # How long to wait before scaling down replicas - downscale_delay_s: NonNegativeFloat = 600.0 - # How long to wait before scaling up replicas - upscale_delay_s: NonNegativeFloat = 30.0 + downscale_delay_s: NonNegativeFloat = Field( + default=600.0, description="How long to wait before scaling down replicas." + ) + upscale_delay_s: NonNegativeFloat = Field( + default=30.0, description="How long to wait before scaling up replicas." + ) # Cloudpickled policy definition. 
_serialized_policy_def: bytes = PrivateAttr(default=b"") diff --git a/python/ray/serve/deployment.py b/python/ray/serve/deployment.py index 68139f5b5bb4..5487ad4d0afc 100644 --- a/python/ray/serve/deployment.py +++ b/python/ray/serve/deployment.py @@ -7,10 +7,10 @@ from ray.serve._private.config import ( DeploymentConfig, ReplicaConfig, + RequestRouterConfig, handle_num_replicas_auto, ) from ray.serve._private.constants import SERVE_LOGGER_NAME -from ray.serve._private.request_router.request_router import RequestRouter from ray.serve._private.usage import ServeUsageTag from ray.serve._private.utils import DEFAULT, Default from ray.serve.config import AutoscalingConfig @@ -237,9 +237,9 @@ def options( health_check_period_s: Default[float] = DEFAULT.VALUE, health_check_timeout_s: Default[float] = DEFAULT.VALUE, logging_config: Default[Union[Dict, LoggingConfig, None]] = DEFAULT.VALUE, - request_router_class: Default[Union[str, RequestRouter, None]] = DEFAULT.VALUE, - request_routing_stats_period_s: Default[float] = DEFAULT.VALUE, - request_routing_stats_timeout_s: Default[float] = DEFAULT.VALUE, + request_router_config: Default[ + Union[Dict, RequestRouterConfig, None] + ] = DEFAULT.VALUE, _init_args: Default[Tuple[Any]] = DEFAULT.VALUE, _init_kwargs: Default[Dict[Any, Any]] = DEFAULT.VALUE, _internal: bool = False, @@ -351,6 +351,9 @@ def options( if autoscaling_config is not DEFAULT.VALUE: new_deployment_config.autoscaling_config = autoscaling_config + if request_router_config is not DEFAULT.VALUE: + new_deployment_config.request_router_config = request_router_config + if graceful_shutdown_wait_loop_s is not DEFAULT.VALUE: new_deployment_config.graceful_shutdown_wait_loop_s = ( graceful_shutdown_wait_loop_s @@ -372,19 +375,6 @@ def options( logging_config = logging_config.dict() new_deployment_config.logging_config = logging_config - if request_router_class is not DEFAULT.VALUE: - new_deployment_config.request_router_class = request_router_class - - if 
request_routing_stats_period_s is not DEFAULT.VALUE: - new_deployment_config.request_routing_stats_period_s = ( - request_routing_stats_period_s - ) - - if request_routing_stats_timeout_s is not DEFAULT.VALUE: - new_deployment_config.request_routing_stats_timeout_s = ( - request_routing_stats_timeout_s - ) - new_replica_config = ReplicaConfig.create( func_or_class, init_args=_init_args, @@ -453,8 +443,7 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema: "placement_group_bundles": d._replica_config.placement_group_bundles, "max_replicas_per_node": d._replica_config.max_replicas_per_node, "logging_config": d._deployment_config.logging_config, - "request_routing_stats_period_s": d._deployment_config.request_routing_stats_period_s, - "request_routing_stats_timeout_s": d._deployment_config.request_routing_stats_timeout_s, + "request_router_config": d._deployment_config.request_router_config, } # Let non-user-configured options be set to defaults. If the schema @@ -515,8 +504,7 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment: health_check_period_s=s.health_check_period_s, health_check_timeout_s=s.health_check_timeout_s, logging_config=s.logging_config, - request_routing_stats_period_s=s.request_routing_stats_period_s, - request_routing_stats_timeout_s=s.request_routing_stats_timeout_s, + request_router_config=s.request_router_config, ) deployment_config.user_configured_option_names = ( s._get_user_configured_option_names() diff --git a/python/ray/serve/schema.py b/python/ray/serve/schema.py index 80904f875765..607097fee8a7 100644 --- a/python/ray/serve/schema.py +++ b/python/ray/serve/schema.py @@ -33,7 +33,7 @@ ) from ray.serve._private.deployment_info import DeploymentInfo from ray.serve._private.utils import DEFAULT -from ray.serve.config import ProxyLocation +from ray.serve.config import ProxyLocation, RequestRouterConfig from ray.util.annotations import PublicAPI # Shared amongst multiple schemas. 
@@ -405,25 +405,9 @@ class DeploymentSchema(BaseModel, allow_population_by_field_name=True): default=DEFAULT.VALUE, description="Logging config for configuring serve deployment logs.", ) - request_router_class: str = Field( + request_router_config: Union[Dict, RequestRouterConfig] = Field( default=DEFAULT.VALUE, - description="The path pointing to the custom request router class to use for this deployment.", - ) - request_routing_stats_period_s: float = Field( - default=DEFAULT.VALUE, - description=( - "Frequency at which the controller will record routing stats " - "replicas. Uses a default if null." - ), - gt=0, - ) - request_routing_stats_timeout_s: float = Field( - default=DEFAULT.VALUE, - description=( - "Timeout that the controller will wait for a response " - "from the replica's record routing stats. Uses a default if null." - ), - gt=0, + description="Config for the request router used for this deployment.", ) @root_validator @@ -503,9 +487,7 @@ def _deployment_info_to_schema(name: str, info: DeploymentInfo) -> DeploymentSch health_check_period_s=info.deployment_config.health_check_period_s, health_check_timeout_s=info.deployment_config.health_check_timeout_s, ray_actor_options=info.replica_config.ray_actor_options, - request_router_class=info.deployment_config.request_router_class, - request_routing_stats_period_s=info.deployment_config.request_routing_stats_period_s, - request_routing_stats_timeout_s=info.deployment_config.request_routing_stats_timeout_s, + request_router_config=info.deployment_config.request_router_config, ) if info.deployment_config.autoscaling_config is not None: @@ -1203,15 +1185,17 @@ def _get_user_facing_json_serializable_dict( """Generates json serializable dictionary with user facing data.""" values = super().dict(*args, **kwargs) - # `serialized_policy_def` and `serialized_request_router_cls` are only used + # `serialized_policy_def` and internal router config fields are only used # internally and should not be exposed to the 
REST api. This method iteratively - # removes them from each deployment and autoscaling config if exists. + # removes them from each deployment config if exists. for app_name, application in values["applications"].items(): for deployment_name, deployment in application["deployments"].items(): if "deployment_config" in deployment: - deployment["deployment_config"].pop( - "serialized_request_router_cls", None - ) + # Remove internal fields from request_router_config if it exists + if "request_router_config" in deployment["deployment_config"]: + deployment["deployment_config"]["request_router_config"].pop( + "_serialized_request_router_cls", None + ) if "autoscaling_config" in deployment["deployment_config"]: deployment["deployment_config"]["autoscaling_config"].pop( "_serialized_policy_def", None diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index 35e671183f20..c74ee0467190 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -28,6 +28,7 @@ RequestRouter, ) from ray.serve._private.test_utils import get_application_url +from ray.serve.config import RequestRouterConfig from ray.serve.deployment import Application from ray.serve.exceptions import RayServeException from ray.serve.handle import DeploymentHandle @@ -80,8 +81,14 @@ async def choose_replicas( ) -> List[List[RunningReplica]]: return [candidate_replicas] + def initialize_state(self, test_parameter: int = 0): + print("Called initialize_state in FakeRequestRouter") + self.test_parameter = test_parameter -@serve.deployment(request_router_class=FakeRequestRouter) + +@serve.deployment( + request_router_config=RequestRouterConfig(request_router_class=FakeRequestRouter) +) class AppWithCustomRequestRouter: def __call__(self) -> str: return "Hello, world!" @@ -1112,6 +1119,24 @@ def test_deploy_app_with_custom_request_router(serve_instance): assert handle.remote().result() == "Hello, world!" 
+@serve.deployment( + request_router_config=RequestRouterConfig( + request_router_class="ray.serve.tests.test_api.FakeRequestRouter", + request_router_kwargs=dict(test_parameter=4848), + ) +) +class AppWithCustomRequestRouterAndKwargs: + def __call__(self) -> str: + return "Hello, world!" + + +def test_custom_request_router_kwargs(serve_instance): + """Check that custom kwargs can be passed to the request router.""" + + handle = serve.run(AppWithCustomRequestRouterAndKwargs.bind()) + assert handle.remote().result() == "Hello, world!" + + if __name__ == "__main__": import sys diff --git a/python/ray/serve/tests/test_controller.py b/python/ray/serve/tests/test_controller.py index eb14d6c0da40..d3696ca521fb 100644 --- a/python/ray/serve/tests/test_controller.py +++ b/python/ray/serve/tests/test_controller.py @@ -185,9 +185,12 @@ def autoscaling_app(): "ray_actor_options": { "num_cpus": 1.0, }, - "request_router_class": "ray.serve._private.request_router:PowerOfTwoChoicesRequestRouter", - "request_routing_stats_period_s": 10.0, - "request_routing_stats_timeout_s": 30.0, + "request_router_config": { + "request_router_class": "ray.serve._private.request_router:PowerOfTwoChoicesRequestRouter", + "request_router_kwargs": {}, + "request_routing_stats_period_s": 10.0, + "request_routing_stats_timeout_s": 30.0, + }, }, "target_num_replicas": 1, "required_resources": {"CPU": 1}, diff --git a/python/ray/serve/tests/test_record_routing_stats.py b/python/ray/serve/tests/test_record_routing_stats.py index fae97c6f22db..100e57859a32 100644 --- a/python/ray/serve/tests/test_record_routing_stats.py +++ b/python/ray/serve/tests/test_record_routing_stats.py @@ -7,12 +7,15 @@ from ray import serve from ray._common.test_utils import wait_for_condition from ray.serve._private.common import ReplicaID +from ray.serve.config import RequestRouterConfig from ray.serve.context import _get_internal_replica_context from ray.serve.handle import DeploymentHandle @serve.deployment( - 
request_routing_stats_period_s=0.1, request_routing_stats_timeout_s=0.1 + request_router_config=RequestRouterConfig( + request_routing_stats_period_s=0.1, request_routing_stats_timeout_s=0.1 + ) ) class Patient: def __init__(self): diff --git a/python/ray/serve/tests/test_telemetry_2.py b/python/ray/serve/tests/test_telemetry_2.py index 55f4eda95847..0413400d5bf2 100644 --- a/python/ray/serve/tests/test_telemetry_2.py +++ b/python/ray/serve/tests/test_telemetry_2.py @@ -17,6 +17,7 @@ ) from ray.serve._private.test_utils import check_apps_running, check_telemetry from ray.serve._private.usage import ServeUsageTag +from ray.serve.config import RequestRouterConfig from ray.serve.context import _get_global_client from ray.serve.schema import ServeDeploySchema @@ -159,7 +160,9 @@ def test_custom_request_router_telemetry(manage_ray_with_telemetry): check_telemetry(ServeUsageTag.CUSTOM_REQUEST_ROUTER_USED, expected=None) @serve.deployment( - request_router_class=CustomRequestRouter, + request_router_config=RequestRouterConfig( + request_router_class=CustomRequestRouter, + ), ) class CustomRequestRouterApp: async def __call__(self) -> str: diff --git a/python/ray/serve/tests/unit/test_config.py b/python/ray/serve/tests/unit/test_config.py index a1951f319455..41ef5010eed5 100644 --- a/python/ray/serve/tests/unit/test_config.py +++ b/python/ray/serve/tests/unit/test_config.py @@ -15,6 +15,7 @@ DeploymentMode, HTTPOptions, ProxyLocation, + RequestRouterConfig, gRPCOptions, ) from ray.serve.generated.serve_pb2 import ( @@ -144,27 +145,43 @@ def test_setting_and_getting_request_router_class(self): # Passing request_router_class as a class. 
deployment_config = DeploymentConfig.from_default( - request_router_class=FakeRequestRouter + request_router_config=RequestRouterConfig( + request_router_class=FakeRequestRouter + ) + ) + assert ( + deployment_config.request_router_config.request_router_class + == request_router_path + ) + assert ( + deployment_config.request_router_config.get_request_router_class() + == FakeRequestRouter ) - assert deployment_config.request_router_class == request_router_path - assert deployment_config.get_request_router_class() == FakeRequestRouter # Passing request_router_class as an import path. deployment_config = DeploymentConfig.from_default( - request_router_class=request_router_path + request_router_config=RequestRouterConfig( + request_router_class=request_router_path + ) + ) + assert ( + deployment_config.request_router_config.request_router_class + == request_router_path + ) + assert ( + deployment_config.request_router_config.get_request_router_class() + == FakeRequestRouter ) - assert deployment_config.request_router_class == request_router_path - assert deployment_config.get_request_router_class() == FakeRequestRouter # Not passing request_router_class should # default to `PowerOfTwoChoicesRequestRouter`. deployment_config = DeploymentConfig.from_default() assert ( - deployment_config.request_router_class + deployment_config.request_router_config.request_router_class == "ray.serve._private.request_router:PowerOfTwoChoicesRequestRouter" ) assert ( - deployment_config.get_request_router_class() + deployment_config.request_router_config.get_request_router_class() == PowerOfTwoChoicesRequestRouter ) diff --git a/src/ray/protobuf/serve.proto b/src/ray/protobuf/serve.proto index 62f6145680e5..ebfc36096a01 100644 --- a/src/ray/protobuf/serve.proto +++ b/src/ray/protobuf/serve.proto @@ -91,6 +91,27 @@ message LoggingConfig { //[End] Logging Config +//[Begin] ROUTING CONFIG +message RequestRouterConfig { + // Cloudpickled request router definition. 
+ bytes _serialized_request_router_cls = 1; + + // The import path of the request router if the user passed a string. It's the + // concatenation of the request router module and the request router name + // if the user passed a callable. + string request_router_class = 2; + + // Frequency at which the controller records routing stats for a replica. + double request_routing_stats_period_s = 3; + + // Timeout the controller waits for a replica's record routing stats response. + double request_routing_stats_timeout_s = 4; + + // kwargs which Ray Serve passes to the router class's initialize_state method. + bytes request_router_kwargs = 5; +} +//[End] ROUTING CONFIG + // Configuration options for a deployment, to be set by the user. message DeploymentConfig { // The number of processes to start up that will handle requests to this deployment. @@ -135,19 +156,8 @@ message DeploymentConfig { LoggingConfig logging_config = 14; - // Cloudpickled request router definition. - bytes serialized_request_router_cls = 15; - - // The import path of the request router if user passed a string. Will be the - // concatenation of the request router module and the request router name - // if user passed a callable. - string request_router_class = 16; - - // Frequency at which the controller records routing stats for a replica. - double request_routing_stats_period_s = 17; - - // Timeout after which a replica started a record routing stats without a response. - double request_routing_stats_timeout_s = 18; + // The deployment's routing configuration. + RequestRouterConfig request_router_config = 19; } // Deployment language.