Commit bf48b7d
[API server] handle logs request in event loop
Signed-off-by: Aylei <[email protected]>
Parent: 50fd062

16 files changed: +387 / -88 lines

sky/backends/backend_utils.py (+3)

@@ -40,6 +40,7 @@
 from sky.utils import command_runner
 from sky.utils import common
 from sky.utils import common_utils
+from sky.utils import context
 from sky.utils import controller_utils
 from sky.utils import env_options
 from sky.utils import registry
@@ -2185,6 +2186,7 @@ def refresh_cluster_record(
 
 
 @timeline.event
+@context.cancellation_guard
 def refresh_cluster_status_handle(
     cluster_name: str,
     *,
@@ -2234,6 +2236,7 @@ def check_cluster_available(
     ...
 
 
+@context.cancellation_guard
 def check_cluster_available(
     cluster_name: str,
     *,
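
Note: the @context.cancellation_guard decorator added above is what lets these blocking helpers bail out once the API server cancels a request. The sky.utils.context module itself is not included in this diff, so the following is only a minimal sketch of the pattern, assuming a per-request context object that carries a cancellation flag; all names and behavior below are illustrative, not the actual implementation.

# Illustrative sketch only: a per-request context with a cancellation flag,
# checked by a decorator before the wrapped function runs.
import asyncio
import contextvars
import functools
from typing import Any, Callable, Optional, TypeVar

F = TypeVar('F', bound=Callable[..., Any])

_CONTEXT: contextvars.ContextVar[Optional['Context']] = contextvars.ContextVar(
    'sky_context', default=None)


class Context:
    """Hypothetical per-request context object."""

    def __init__(self) -> None:
        self._canceled = False

    def cancel(self) -> None:
        self._canceled = True

    def is_canceled(self) -> bool:
        return self._canceled


def initialize() -> None:
    _CONTEXT.set(Context())


def get() -> Optional[Context]:
    return _CONTEXT.get()


def cancellation_guard(func: F) -> F:
    """Abort the wrapped call early if the current request was cancelled."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        ctx = get()
        if ctx is not None and ctx.is_canceled():
            raise asyncio.CancelledError(
                f'{func.__name__} aborted: request cancelled')
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]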

sky/backends/cloud_vm_ray_backend.py (+4, -1)

@@ -55,7 +55,7 @@
 from sky.skylet import job_lib
 from sky.skylet import log_lib
 from sky.usage import usage_lib
-from sky.utils import accelerator_registry
+from sky.utils import accelerator_registry, context
 from sky.utils import annotations
 from sky.utils import cluster_utils
 from sky.utils import command_runner
@@ -2410,6 +2410,7 @@ def is_provided_ips_valid(
             internal_external_ips[1:], key=lambda x: x[1])
         self.stable_internal_external_ips = stable_internal_external_ips
 
+    @context.cancellation_guard
     @annotations.lru_cache(scope='global')
     @timeline.event
     def get_command_runners(self,
@@ -3817,6 +3818,7 @@ def _rsync_down(args) -> None:
         subprocess_utils.run_in_parallel(_rsync_down, parallel_args)
         return dict(zip(job_ids, local_log_dirs))
 
+    @context.cancellation_guard
     def tail_logs(self,
                   handle: CloudVmRayResourceHandle,
                   job_id: Optional[int],
@@ -4534,6 +4536,7 @@ def is_definitely_autostopping(self,
     # TODO(zhwu): Refactor this to a CommandRunner class, so different backends
     # can support its own command runner.
     @timeline.event
+    @context.cancellation_guard
    def run_on_head(
        self,
        handle: CloudVmRayResourceHandle,

sky/client/sdk.py (+3)

@@ -83,6 +83,9 @@ def stream_response(request_id: Optional[str],
         logger.debug(f'To stream request logs: sky api logs {request_id}')
         raise
 
+def decode_stream(response: requests.Response,
+                  output_stream: Optional['io.TextIOBase'] = None) -> str:
+
 
 @usage_lib.entrypoint
 @server_common.check_server_healthy_or_start

sky/core.py (+3)

@@ -31,6 +31,7 @@
 from sky.utils import admin_policy_utils
 from sky.utils import common
 from sky.utils import common_utils
+from sky.utils import context
 from sky.utils import controller_utils
 from sky.utils import rich_utils
 from sky.utils import status_lib
@@ -851,6 +852,7 @@ def tail_logs(cluster_name: str,
         script.
 
     """
+    logger.info('tail_log: go to backend_utils.check_cluster_available')
     # Check the status of the cluster.
    handle = backend_utils.check_cluster_available(
        cluster_name,
@@ -859,6 +861,7 @@ def tail_logs(cluster_name: str,
     backend = backend_utils.get_backend_from_handle(handle)
 
     usage_lib.record_cluster_name_for_current_operation(cluster_name)
+    logger.info('tail_log: go to backend.tail_logs')
     return backend.tail_logs(handle, job_id, follow=follow, tail=tail)

sky/global_user_state.py (+2, -1)

@@ -18,7 +18,7 @@
 
 from sky import models
 from sky import sky_logging
-from sky.utils import common_utils
+from sky.utils import common_utils, context
 from sky.utils import db_utils
 from sky.utils import registry
 from sky.utils import status_lib
@@ -670,6 +670,7 @@ def _load_storage_mounts_metadata(
     return pickle.loads(record_storage_mounts_metadata)
 
 
+@context.cancellation_guard
 def get_cluster_from_name(
         cluster_name: Optional[str]) -> Optional[Dict[str, Any]]:
     rows = _DB.cursor.execute(

sky/server/requests/executor.py (+62, -23)

@@ -20,8 +20,10 @@
 """
 import contextlib
 import enum
+import logging
 import multiprocessing
 import os
+import pathlib
 import queue as queue_lib
 import signal
 import sys
@@ -45,7 +47,7 @@
 from sky.server.requests.queues import local_queue
 from sky.server.requests.queues import mp_queue
 from sky.skylet import constants
-from sky.utils import annotations
+from sky.utils import annotations, context
 from sky.utils import common_utils
 from sky.utils import subprocess_utils
 from sky.utils import timeline
@@ -60,7 +62,6 @@
 from typing_extensions import ParamSpec
 
 P = ParamSpec('P')
-
 logger = sky_logging.init_logger(__name__)
 
 # On macOS, the default start method for multiprocessing is 'fork', which
@@ -325,7 +326,6 @@ def _restore_output(original_stdout: int, original_stderr: int) -> None:
 def _sigterm_handler(signum: int, frame: Optional['types.FrameType']) -> None:
     raise KeyboardInterrupt
 
-
 def _request_execution_wrapper(request_id: str,
                                ignore_return_value: bool) -> None:
     """Wrapper for a request execution.
@@ -389,6 +389,57 @@ def _request_execution_wrapper(request_id: str,
         _restore_output(original_stdout, original_stderr)
     logger.info(f'Request {request_id} finished')
 
+async def execute_request(request: api_requests.Request):
+    """Execute a request in current event loop.
+
+    Similar to _request_execution_wrapper, but executed as coroutine in current
+    event loop. This is designed for executing tasks that are not CPU
+    intensive, e.g. sky logs.
+    """
+    ctx = context.get()
+    if ctx is None:
+        raise ValueError('Context is not initialized')
+    log_path = request.log_path
+    func = request.entrypoint
+    request_body = request.request_body
+    with api_requests.update_request(request.request_id) as request_task:
+        request_task.status = api_requests.RequestStatus.RUNNING
+    ctx.log_handler = logging.FileHandler(log_path.absolute())
+    sky_logging.reload_logger()
+
+def prepare_request(
+    request_id: str,
+    request_name: str,
+    request_body: payloads.RequestBody,
+    func: Callable[P, Any],
+    request_cluster_name: Optional[str] = None,
+    schedule_type: api_requests.ScheduleType = (
+        api_requests.ScheduleType.LONG),
+    is_skypilot_system: bool = False,
+) -> Optional[api_requests.Request]:
+    """Prepare a request for execution."""
+    user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
+    if is_skypilot_system:
+        user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
+        global_user_state.add_or_update_user(
+            models.User(id=user_id, name=user_id))
+    request = api_requests.Request(request_id=request_id,
+                                   name=server_constants.REQUEST_NAME_PREFIX +
+                                   request_name,
+                                   entrypoint=func,
+                                   request_body=request_body,
+                                   status=api_requests.RequestStatus.PENDING,
+                                   created_at=time.time(),
+                                   schedule_type=schedule_type,
+                                   user_id=user_id,
+                                   cluster_name=request_cluster_name)
+
+    if not api_requests.create_if_not_exists(request):
+        logger.debug(f'Request {request_id} already exists.')
+        return None
+
+    request.log_path.touch()
+    return request
 
 def schedule_request(
     request_id: str,
@@ -421,28 +472,16 @@ def schedule_request(
         The precondition is waited asynchronously and does not block the
         caller.
     """
-    user_id = request_body.env_vars[constants.USER_ID_ENV_VAR]
-    if is_skypilot_system:
-        user_id = server_constants.SKYPILOT_SYSTEM_USER_ID
-        global_user_state.add_or_update_user(
-            models.User(id=user_id, name=user_id))
-    request = api_requests.Request(request_id=request_id,
-                                   name=server_constants.REQUEST_NAME_PREFIX +
-                                   request_name,
-                                   entrypoint=func,
-                                   request_body=request_body,
-                                   status=api_requests.RequestStatus.PENDING,
-                                   created_at=time.time(),
-                                   schedule_type=schedule_type,
-                                   user_id=user_id,
-                                   cluster_name=request_cluster_name)
-
-    if not api_requests.create_if_not_exists(request):
-        logger.debug(f'Request {request_id} already exists.')
+    request = prepare_request(request_id,
+                              request_name,
+                              request_body,
+                              func,
+                              request_cluster_name,
+                              schedule_type,
+                              is_skypilot_system)
+    if request is None:
         return
 
-    request.log_path.touch()
-
     def enqueue():
         input_tuple = (request_id, ignore_return_value)
         logger.info(f'Queuing request: {request_id}')
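
Note: the refactor above splits the old schedule_request body into prepare_request (request record creation and deduplication), so the new execute_request coroutine can reuse it without going through the worker-process queue. execute_request is still incomplete in this commit: it marks the request RUNNING and attaches a log handler but does not yet invoke the entrypoint. A hypothetical caller combining the two pieces might look like the sketch below; the wrapper name and the ClusterJobBody type are assumptions, not part of this diff.

# Hypothetical caller, not part of this commit: reuse prepare_request() and
# run the request on the event loop instead of a worker process.
from sky import core
from sky.server.requests import executor
from sky.server.requests import payloads


async def tail_logs_in_loop(request_id: str,
                            body: 'payloads.ClusterJobBody') -> None:
    request = executor.prepare_request(request_id, 'logs', body,
                                       func=core.tail_logs)
    if request is None:
        # A request with this id was already created; nothing to do.
        return
    # I/O-bound requests such as `sky logs` stay on the event loop rather
    # than occupying a SHORT worker process.
    await executor.execute_request(request)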

sky/server/requests/requests.py (+1, -1)

@@ -11,7 +11,7 @@
 import sqlite3
 import time
 import traceback
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, ParamSpec, Tuple
 
 import colorama
 import filelock

sky/server/server.py (+32, -22)

@@ -5,6 +5,7 @@
 import contextlib
 import dataclasses
 import datetime
+import functools
 import logging
 import multiprocessing
 import os
@@ -43,7 +44,7 @@
 from sky.server.requests import requests as requests_lib
 from sky.skylet import constants
 from sky.usage import usage_lib
-from sky.utils import admin_policy_utils
+from sky.utils import admin_policy_utils, context
 from sky.utils import common as common_lib
 from sky.utils import common_utils
 from sky.utils import dag_utils
@@ -101,7 +102,6 @@ async def dispatch(self, request: fastapi.Request, call_next):
         response.headers['X-Request-ID'] = request_id
         return response
 
-
 # Default expiration time for upload ids before cleanup.
 _DEFAULT_UPLOAD_EXPIRATION_TIME = datetime.timedelta(hours=1)
 # Key: (upload_id, user_hash), Value: the time when the upload id needs to be
@@ -666,31 +666,41 @@ async def logs(
     background_tasks: fastapi.BackgroundTasks
 ) -> fastapi.responses.StreamingResponse:
     """Tails the logs of a job."""
+    del request
     # TODO(zhwu): This should wait for the request on the cluster, e.g., async
     # launch, to finish, so that a user does not need to manually pull the
     # request status.
-    executor.schedule_request(
-        request_id=request.state.request_id,
-        request_name='logs',
-        request_body=cluster_job_body,
-        func=core.tail_logs,
-        # TODO(aylei): We have tail logs scheduled as SHORT request, because it
-        # should be responsive. However, it can be long running if the user's
-        # job keeps running, and we should avoid it taking the SHORT worker.
-        schedule_type=requests_lib.ScheduleType.SHORT,
-        request_cluster_name=cluster_job_body.cluster_name,
-    )
-
-    request_task = requests_lib.get_request(request.state.request_id)
-
+    # Only initialize the context in logs handler to limit the scope of this
+    # experimental change.
+    # TODO(aylei): init in lifespan() to enable SkyPilot context in all APIs.
+    logger.info('Initializing context')
+    context.initialize()
+    ctx = context.get()
+    # ctx.env_vars = cluster_job_body.env_vars
+    # buffer = stream_utils.StreamingBuffer()
+    # ctx.log_stream = buffer
+    logger.info('Starting tail logs')
+    task = asyncio.to_thread(core.tail_logs, **cluster_job_body.to_kwargs())
+    logger.info('Tail logs started')
+
+    async def on_disconnect():
+        ctx.cancel()
+        # buffer.close()
+        await task
+
+    background_tasks.add_task(on_disconnect)
+
+    await task
     # TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
     # the same approach as /stream.
-    return stream_utils.stream_response(
-        request_id=request_task.request_id,
-        logs_path=request_task.log_path,
-        background_tasks=background_tasks,
-    )
-
+    return fastapi.responses.StreamingResponse(
+        "Hello",
+        media_type='text/plain',
+        headers={
+            'Cache-Control': 'no-cache, no-transform',
+            'X-Accel-Buffering': 'no',
+            'Transfer-Encoding': 'chunked'
+        })
 
 @app.get('/users')
 async def users() -> List[Dict[str, Any]]:
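
Note: the rewritten handler runs core.tail_logs off the event loop with asyncio.to_thread and registers a background task (on_disconnect) that cancels the per-request context once the response is finished or the client goes away; the StreamingResponse body is still a placeholder ("Hello") in this commit. Below is a self-contained sketch of the cancel-on-disconnect pattern using stand-in names (Context, blocking_tail) rather than the real SkyPilot pieces.

# Stand-alone sketch of the cancel-on-disconnect pattern; Context and
# blocking_tail are stand-ins for sky.utils.context and core.tail_logs.
import asyncio
import threading
import time

import fastapi

app = fastapi.FastAPI()


class Context:
    def __init__(self) -> None:
        self._cancel = threading.Event()

    def cancel(self) -> None:
        self._cancel.set()

    def is_canceled(self) -> bool:
        return self._cancel.is_set()


def blocking_tail(ctx: Context) -> str:
    """Blocking worker that checks for cancellation between log chunks."""
    lines = []
    for i in range(10):
        if ctx.is_canceled():
            break
        time.sleep(0.1)  # pretend to wait for the next log chunk
        lines.append(f'line {i}\n')
    return ''.join(lines)


@app.get('/logs-sketch')
async def logs_sketch(background_tasks: fastapi.BackgroundTasks) -> str:
    ctx = Context()
    # Wrap the thread in a Task so it can be awaited from multiple places.
    task = asyncio.create_task(asyncio.to_thread(blocking_tail, ctx))

    async def on_disconnect() -> None:
        # Runs after the response is finished (or the client disconnects):
        # cancel the context so the worker thread exits promptly.
        ctx.cancel()
        await task

    background_tasks.add_task(on_disconnect)
    return await task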

sky/server/stream_utils.py (+46)

@@ -2,6 +2,7 @@
 
 import asyncio
 import collections
+import io
 import pathlib
 from typing import AsyncGenerator, Deque, Optional
 
@@ -142,3 +143,48 @@ async def on_disconnect():
             'X-Accel-Buffering': 'no',
             'Transfer-Encoding': 'chunked'
         })
+
+class StreamingBuffer:
+    """Memory backed streaming buffer."""
+    def __init__(self):
+        self._buffer = io.StringIO()
+        self._event = asyncio.Event()
+        self._closed = False
+
+    def write(self, data: str) -> int:
+        """Write data to the buffer."""
+        if self._closed:
+            raise ValueError('Buffer is closed')
+        n = self._buffer.write(data)
+        self._event.set()
+        return n
+
+    def flush(self) -> None:
+        """Flush the buffer."""
+        self._buffer.flush()
+
+    def close(self) -> None:
+        """Close the buffer."""
+        self._closed = True
+        self._event.set()
+
+    async def read(self) -> AsyncGenerator[str, None]:
+        """Read from the buffer as a stream."""
+        # Start position in the buffer
+        pos = 0
+
+        while True:
+            # Get current buffer contents
+            current = self._buffer.getvalue()
+            if pos < len(current):
+                # New data available, yield it
+                chunk = current[pos:]
+                pos = len(current)
+                yield chunk
+            elif self._closed:
+                # Buffer is closed and no more data
+                break
+            else:
+                # Wait for new data
+                self._event.clear()
+                await self._event.wait()
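
Note: StreamingBuffer is added here but only referenced from commented-out code in the logs handler, so its intended consumption is not shown in this commit. A small hypothetical usage keeps both producer and consumer on the same event loop (asyncio.Event is not thread-safe, so writes from another thread would need loop.call_soon_threadsafe).

# Hypothetical usage of StreamingBuffer; not part of this commit.
import asyncio

from sky.server import stream_utils


async def demo() -> None:
    buf = stream_utils.StreamingBuffer()

    async def produce() -> None:
        for i in range(3):
            buf.write(f'log line {i}\n')
            await asyncio.sleep(0.1)
        buf.close()

    producer = asyncio.create_task(produce())
    # read() yields chunks as they are written and stops once the buffer
    # is closed and fully drained.
    async for chunk in buf.read():
        print(chunk, end='')
    await producer


if __name__ == '__main__':
    asyncio.run(demo())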
