Skip to content

Commit b520958

Browse files
authored
[router][grpc] Replace fake health check with correct ones (#11387)
1 parent fa7e2c3 commit b520958

File tree

2 files changed

+97
-11
lines changed

2 files changed

+97
-11
lines changed

python/sglang/srt/entrypoints/grpc_request_manager.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
TokenizedEmbeddingReqInput,
2828
TokenizedGenerateReqInput,
2929
)
30+
from sglang.srt.managers.scheduler import is_health_check_generate_req
3031
from sglang.srt.server_args import PortArgs, ServerArgs
3132
from sglang.srt.utils import get_zmq_socket, kill_process_tree
3233
from sglang.utils import get_exception_traceback
@@ -338,12 +339,9 @@ async def _handle_single_request(
338339
break
339340

340341
except asyncio.TimeoutError:
341-
# Timeout waiting for response - abort and cleanup
342-
logger.warning(
343-
f"Timeout waiting for response for request {request_id}"
344-
)
345-
await self.abort_request(request_id)
346-
return
342+
# The timeout only bounds each wait so client cancellation can be checked periodically;
343+
# it is not a failure, so keep waiting for the scheduler response
344+
continue
347345

348346
finally:
349347
# Always clean up request state when exiting
@@ -412,6 +410,10 @@ async def wait_for_result():
412410

413411
async def abort_request(self, request_id: str) -> bool:
414412
"""Abort a running request."""
413+
# Skip aborting health check requests (they clean themselves up)
414+
if request_id.startswith("HEALTH_CHECK"):
415+
return False
416+
415417
if request_id not in self.rid_to_state:
416418
return False
417419

python/sglang/srt/entrypoints/grpc_server.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ async def Generate(
197197
context: grpc.aio.ServicerContext,
198198
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
199199
"""Handle generation requests with streaming responses."""
200-
logger.debug(f"Receive generation request: {request.request_id}")
200+
logger.info(f"Receive generation request: {request.request_id}")
201201

202202
try:
203203
# Convert gRPC request to internal format
@@ -211,6 +211,13 @@ async def Generate(
211211
)
212212

213213
async for output in response_generator:
214+
# Check if client cancelled before processing/yielding
215+
if context.cancelled():
216+
logger.info(f"Client cancelled request {request.request_id}")
217+
# Explicitly abort the request to notify scheduler
218+
await self.request_manager.abort_request(request.request_id)
219+
break
220+
214221
# Handle batch responses (for n>1 non-streaming)
215222
if isinstance(output, list):
216223
for batch_output in output:
@@ -268,7 +275,7 @@ async def Embed(
268275
_context: grpc.aio.ServicerContext,
269276
) -> sglang_scheduler_pb2.EmbedResponse:
270277
"""Handle embedding requests."""
271-
logger.debug(f"Receive embedding request: {request.request_id}")
278+
logger.info(f"Receive embedding request: {request.request_id}")
272279

273280
try:
274281
# Convert request
@@ -313,9 +320,86 @@ async def HealthCheck(
313320
request: sglang_scheduler_pb2.HealthCheckRequest,
314321
context: grpc.aio.ServicerContext,
315322
) -> sglang_scheduler_pb2.HealthCheckResponse:
316-
"""Health check - always returns healthy after server started."""
323+
"""
324+
Check the health of the inference server by sending a special request to generate one token.
325+
Similar to HTTP server's /health endpoint.
326+
"""
327+
logger.info("Receive health check request")
328+
329+
if self.request_manager.gracefully_exit:
330+
logger.info(
331+
"Health check request received during shutdown. Returning unhealthy."
332+
)
333+
return sglang_scheduler_pb2.HealthCheckResponse(
334+
healthy=False, message="Server is shutting down"
335+
)
336+
337+
# Create a special health check request
338+
rid = f"HEALTH_CHECK_{time.time()}"
339+
sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
340+
sampling_params.normalize(tokenizer=None)
341+
342+
# Create health check request
343+
is_generation = self.scheduler_info.get("is_generation", True)
344+
if is_generation:
345+
health_req = TokenizedGenerateReqInput(
346+
rid=rid,
347+
input_text="",
348+
input_ids=[0],
349+
sampling_params=sampling_params,
350+
return_logprob=False,
351+
logprob_start_len=-1,
352+
top_logprobs_num=0,
353+
stream=False,
354+
mm_inputs=None,
355+
token_ids_logprob=None,
356+
)
357+
# Set disaggregation params if needed
358+
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
359+
health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST
360+
health_req.bootstrap_room = 0
361+
else:
362+
health_req = TokenizedEmbeddingReqInput(
363+
rid=rid,
364+
input_text="",
365+
input_ids=[0],
366+
)
367+
368+
# Submit health check request
369+
async def run_health_check():
370+
try:
371+
async for _ in self.request_manager.generate_request(
372+
obj=health_req,
373+
request_id=rid,
374+
):
375+
# Got at least one response, server is healthy
376+
return True
377+
except Exception as e:
378+
logger.warning(f"Health check failed: {e}")
379+
return False
380+
return False
381+
382+
task = asyncio.create_task(run_health_check())
383+
384+
# Wait for response with timeout
385+
tic = time.time()
386+
while time.time() < tic + HEALTH_CHECK_TIMEOUT:
387+
await asyncio.sleep(1)
388+
# Check if we got a response from scheduler
389+
if self.request_manager.last_receive_tstamp > tic:
390+
task.cancel()
391+
# Clean up health check state
392+
self.request_manager._cleanup_request_state(rid)
393+
return sglang_scheduler_pb2.HealthCheckResponse(
394+
healthy=True, message="Health check passed"
395+
)
396+
397+
# Timeout - server not responding
398+
task.cancel()
399+
self.request_manager._cleanup_request_state(rid)
400+
logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s")
317401
return sglang_scheduler_pb2.HealthCheckResponse(
318-
healthy=True, message="Health check passed"
402+
healthy=False, message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s"
319403
)
320404

321405
async def Abort(
@@ -324,7 +408,7 @@ async def Abort(
324408
_context: grpc.aio.ServicerContext,
325409
) -> sglang_scheduler_pb2.AbortResponse:
326410
"""Abort an ongoing request."""
327-
logger.debug(f"Receive abort request: {request.request_id}")
411+
logger.info(f"Receive abort request: {request.request_id}")
328412

329413
try:
330414
success = await self.request_manager.abort_request(request.request_id)

0 commit comments

Comments
 (0)