|
20 | 20 | """
|
21 | 21 | import contextlib
|
22 | 22 | import enum
|
| 23 | +import logging |
23 | 24 | import multiprocessing
|
24 | 25 | import os
|
| 26 | +import pathlib |
25 | 27 | import queue as queue_lib
|
26 | 28 | import signal
|
27 | 29 | import sys
|
|
45 | 47 | from sky.server.requests.queues import local_queue
|
46 | 48 | from sky.server.requests.queues import mp_queue
|
47 | 49 | from sky.skylet import constants
|
48 |
| -from sky.utils import annotations |
| 50 | +from sky.utils import annotations, context |
49 | 51 | from sky.utils import common_utils
|
50 | 52 | from sky.utils import subprocess_utils
|
51 | 53 | from sky.utils import timeline
|
|
60 | 62 | from typing_extensions import ParamSpec
|
61 | 63 |
|
62 | 64 | P = ParamSpec('P')
|
63 |
| - |
64 | 65 | logger = sky_logging.init_logger(__name__)
|
65 | 66 |
|
66 | 67 | # On macOS, the default start method for multiprocessing is 'fork', which
|
@@ -325,7 +326,6 @@ def _restore_output(original_stdout: int, original_stderr: int) -> None:
|
325 | 326 | def _sigterm_handler(signum: int, frame: Optional['types.FrameType']) -> None:
|
326 | 327 | raise KeyboardInterrupt
|
327 | 328 |
|
328 |
| - |
329 | 329 | def _request_execution_wrapper(request_id: str,
|
330 | 330 | ignore_return_value: bool) -> None:
|
331 | 331 | """Wrapper for a request execution.
|
@@ -389,6 +389,57 @@ def _request_execution_wrapper(request_id: str,
|
389 | 389 | _restore_output(original_stdout, original_stderr)
|
390 | 390 | logger.info(f'Request {request_id} finished')
|
391 | 391 |
|
| 392 | +async def execute_request(request: api_requests.Request): |
| 393 | + """Execute a request in current event loop. |
| 394 | + |
| 395 | + Similar to _request_execution_wrapper, but executed as coroutine in current |
| 396 | + event loop. This is designed for executing tasks that are not CPU |
| 397 | + intensive, e.g. sky logs. |
| 398 | + """ |
| 399 | + ctx = context.get() |
| 400 | + if ctx is None: |
| 401 | + raise ValueError('Context is not initialized') |
| 402 | + log_path = request.log_path |
| 403 | + func = request.entrypoint |
| 404 | + request_body = request.request_body |
| 405 | + with api_requests.update_request(request.request_id) as request_task: |
| 406 | + request_task.status = api_requests.RequestStatus.RUNNING |
| 407 | + ctx.log_handler = logging.FileHandler(log_path.absolute()) |
| 408 | + sky_logging.reload_logger() |
| 409 | + |
| 410 | +def prepare_request( |
| 411 | + request_id: str, |
| 412 | + request_name: str, |
| 413 | + request_body: payloads.RequestBody, |
| 414 | + func: Callable[P, Any], |
| 415 | + request_cluster_name: Optional[str] = None, |
| 416 | + schedule_type: api_requests.ScheduleType = ( |
| 417 | + api_requests.ScheduleType.LONG), |
| 418 | + is_skypilot_system: bool = False, |
| 419 | +) -> Optional[api_requests.Request]: |
| 420 | + """Prepare a request for execution.""" |
| 421 | + user_id = request_body.env_vars[constants.USER_ID_ENV_VAR] |
| 422 | + if is_skypilot_system: |
| 423 | + user_id = server_constants.SKYPILOT_SYSTEM_USER_ID |
| 424 | + global_user_state.add_or_update_user( |
| 425 | + models.User(id=user_id, name=user_id)) |
| 426 | + request = api_requests.Request(request_id=request_id, |
| 427 | + name=server_constants.REQUEST_NAME_PREFIX + |
| 428 | + request_name, |
| 429 | + entrypoint=func, |
| 430 | + request_body=request_body, |
| 431 | + status=api_requests.RequestStatus.PENDING, |
| 432 | + created_at=time.time(), |
| 433 | + schedule_type=schedule_type, |
| 434 | + user_id=user_id, |
| 435 | + cluster_name=request_cluster_name) |
| 436 | + |
| 437 | + if not api_requests.create_if_not_exists(request): |
| 438 | + logger.debug(f'Request {request_id} already exists.') |
| 439 | + return None |
| 440 | + |
| 441 | + request.log_path.touch() |
| 442 | + return request |
392 | 443 |
|
393 | 444 | def schedule_request(
|
394 | 445 | request_id: str,
|
@@ -421,28 +472,16 @@ def schedule_request(
|
421 | 472 | The precondition is waited asynchronously and does not block the
|
422 | 473 | caller.
|
423 | 474 | """
|
424 |
| - user_id = request_body.env_vars[constants.USER_ID_ENV_VAR] |
425 |
| - if is_skypilot_system: |
426 |
| - user_id = server_constants.SKYPILOT_SYSTEM_USER_ID |
427 |
| - global_user_state.add_or_update_user( |
428 |
| - models.User(id=user_id, name=user_id)) |
429 |
| - request = api_requests.Request(request_id=request_id, |
430 |
| - name=server_constants.REQUEST_NAME_PREFIX + |
431 |
| - request_name, |
432 |
| - entrypoint=func, |
433 |
| - request_body=request_body, |
434 |
| - status=api_requests.RequestStatus.PENDING, |
435 |
| - created_at=time.time(), |
436 |
| - schedule_type=schedule_type, |
437 |
| - user_id=user_id, |
438 |
| - cluster_name=request_cluster_name) |
439 |
| - |
440 |
| - if not api_requests.create_if_not_exists(request): |
441 |
| - logger.debug(f'Request {request_id} already exists.') |
| 475 | + request = prepare_request(request_id, |
| 476 | + request_name, |
| 477 | + request_body, |
| 478 | + func, |
| 479 | + request_cluster_name, |
| 480 | + schedule_type, |
| 481 | + is_skypilot_system) |
| 482 | + if request is None: |
442 | 483 | return
|
443 | 484 |
|
444 |
| - request.log_path.touch() |
445 |
| - |
446 | 485 | def enqueue():
|
447 | 486 | input_tuple = (request_id, ignore_return_value)
|
448 | 487 | logger.info(f'Queuing request: {request_id}')
|
|
0 commit comments