diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index 10a46422887e..52d3394a96a1 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -2,8 +2,12 @@
 from dataclasses import dataclass
 
 import pytest
+import torch
 
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm import SamplingParams
+from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+
+from ..utils import wait_for_gpu_memory_to_clear
 
 
 @dataclass
@@ -94,3 +98,35 @@ async def test_new_requests_event():
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
+
+
+def test_asyncio_run():
+    wait_for_gpu_memory_to_clear(
+        devices=list(range(torch.cuda.device_count())),
+        threshold_bytes=2 * 2**30,
+        timeout_s=60,
+    )
+
+    engine = AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m"))
+
+    async def run(prompt: str):
+        sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=32,
+        )
+
+        async for output in engine.generate(prompt,
+                                            sampling_params,
+                                            request_id=prompt):
+            final_output = output
+        return final_output
+
+    async def generate():
+        return await asyncio.gather(
+            run("test0"),
+            run("test1"),
+        )
+
+    results = asyncio.run(generate())
+    assert len(results) == 2
diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py
index 86103cf85484..60dfe33f2918 100644
--- a/tests/spec_decode/e2e/conftest.py
+++ b/tests/spec_decode/e2e/conftest.py
@@ -1,5 +1,4 @@
 import asyncio
-import time
 from itertools import cycle
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -7,12 +6,6 @@
 import ray
 import torch
 
-from vllm.utils import is_hip
-
-if (not is_hip()):
-    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
-                        nvmlInit)
-
 from vllm import LLM
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -26,6 +19,7 @@
 from vllm.utils import Counter, random_uuid
 
 from ...conftest import cleanup
+from ...utils import wait_for_gpu_memory_to_clear
 
 
 class AsyncLLM:
@@ -291,38 +285,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
         print(f'{i=} {baseline_token_ids=}')
         print(f'{i=} {spec_token_ids=}')
         assert baseline_token_ids == spec_token_ids
-
-
-def wait_for_gpu_memory_to_clear(devices: List[int],
-                                 threshold_bytes: int,
-                                 timeout_s: float = 120) -> None:
-    # Use nvml instead of pytorch to reduce measurement error from torch cuda
-    # context.
-    nvmlInit()
-    start_time = time.time()
-    while True:
-        output: Dict[int, str] = {}
-        output_raw: Dict[int, float] = {}
-        for device in devices:
-            dev_handle = nvmlDeviceGetHandleByIndex(device)
-            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
-            gb_used = mem_info.used / 2**30
-            output_raw[device] = gb_used
-            output[device] = f'{gb_used:.02f}'
-
-        print('gpu memory used (GB): ', end='')
-        for k, v in output.items():
-            print(f'{k}={v}; ', end='')
-        print('')
-
-        dur_s = time.time() - start_time
-        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
-            print(f'Done waiting for free GPU memory on devices {devices=} '
-                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
-            break
-
-        if dur_s >= timeout_s:
-            raise ValueError(f'Memory of devices {devices=} not free after '
-                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
-
-        time.sleep(5)
diff --git a/tests/utils.py b/tests/utils.py
index f2b2d22b1ebc..bc30515c8310 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,7 +4,7 @@
 import time
 import warnings
 from contextlib import contextmanager
-from typing import List
+from typing import Dict, List
 
 import openai
 import ray
@@ -13,7 +13,11 @@
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.utils import get_open_port
+from vllm.utils import get_open_port, is_hip
+
+if (not is_hip()):
+    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
+                        nvmlInit)
 
 # Path to root of repository so that utilities can be imported by ray workers
 VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -154,3 +158,38 @@ def error_on_warning():
         warnings.simplefilter("error")
 
         yield
+
+
+def wait_for_gpu_memory_to_clear(devices: List[int],
+                                 threshold_bytes: int,
+                                 timeout_s: float = 120) -> None:
+    # Use nvml instead of pytorch to reduce measurement error from torch cuda
+    # context.
+    nvmlInit()
+    start_time = time.time()
+    while True:
+        output: Dict[int, str] = {}
+        output_raw: Dict[int, float] = {}
+        for device in devices:
+            dev_handle = nvmlDeviceGetHandleByIndex(device)
+            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
+            gb_used = mem_info.used / 2**30
+            output_raw[device] = gb_used
+            output[device] = f'{gb_used:.02f}'
+
+        print('gpu memory used (GB): ', end='')
+        for k, v in output.items():
+            print(f'{k}={v}; ', end='')
+        print('')
+
+        dur_s = time.time() - start_time
+        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
+            print(f'Done waiting for free GPU memory on devices {devices=} '
+                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
+            break
+
+        if dur_s >= timeout_s:
+            raise ValueError(f'Memory of devices {devices=} not free after '
+                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
+
+        time.sleep(5)
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index ab312850b9ec..fb4457c78be2 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -10,6 +10,7 @@
 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
 from vllm.inputs import LLMInputs, PromptInputs
@@ -540,8 +541,8 @@ async def run_engine_loop(self):
             # Abort if iteration takes too long due to unrecoverable errors
             # (eg. NCCL timeouts).
             try:
-                has_requests_in_progress = await asyncio.wait_for(
-                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+                async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
+                    has_requests_in_progress = await self.engine_step()
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py
new file mode 100644
index 000000000000..4b1842625212
--- /dev/null
+++ b/vllm/engine/async_timeout.py
@@ -0,0 +1,189 @@
+# Workaround for https://github.com/python/cpython/issues/86296
+#
+# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
+# Licensed under the Apache License (Apache-2.0)
+
+import asyncio
+import enum
+import sys
+import warnings
+from types import TracebackType
+from typing import Any, Optional, Type
+
+if sys.version_info[:2] >= (3, 11):
+    from asyncio import timeout as asyncio_timeout
+else:
+
+    def asyncio_timeout(delay: Optional[float]) -> "Timeout":
+        """timeout context manager.
+        Useful in cases when you want to apply timeout logic around block
+        of code or in cases when asyncio.wait_for is not suitable. For example:
+        >>> async with timeout(0.001):
+        ...     async with aiohttp.get('https://github.com') as r:
+        ...         await r.text()
+        delay - value in seconds or None to disable timeout logic
+        """
+        loop = asyncio.get_running_loop()
+        deadline = loop.time() + delay if delay is not None else None
+        return Timeout(deadline, loop)
+
+    class _State(enum.Enum):
+        INIT = "INIT"
+        ENTER = "ENTER"
+        TIMEOUT = "TIMEOUT"
+        EXIT = "EXIT"
+
+    class Timeout:
+        # Internal class, please don't instantiate it directly
+        # Use timeout() and timeout_at() public factories instead.
+        #
+        # Implementation note: `async with timeout()` is preferred
+        # over `with timeout()`.
+        # While technically the Timeout class implementation
+        # doesn't need to be async at all,
+        # the `async with` statement explicitly points that
+        # the context manager should be used from async function context.
+        #
+        # This design allows to avoid many silly misusages.
+        #
+        # TimeoutError is raised immediately when scheduled
+        # if the deadline is passed.
+        # The purpose is to time out as soon as possible
+        # without waiting for the next await expression.
+
+        __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
+
+        def __init__(self, deadline: Optional[float],
+                     loop: asyncio.AbstractEventLoop) -> None:
+            self._loop = loop
+            self._state = _State.INIT
+
+            self._timeout_handler = None  # type: Optional[asyncio.Handle]
+            if deadline is None:
+                self._deadline = None  # type: Optional[float]
+            else:
+                self.update(deadline)
+
+        def __enter__(self) -> "Timeout":
+            warnings.warn(
+                "with timeout() is deprecated, use async with timeout()",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self._do_enter()
+            return self
+
+        def __exit__(
+            self,
+            exc_type: Optional[Type[BaseException]],
+            exc_val: Optional[BaseException],
+            exc_tb: Optional[TracebackType],
+        ) -> Optional[bool]:
+            self._do_exit(exc_type)
+            return None
+
+        async def __aenter__(self) -> "Timeout":
+            self._do_enter()
+            return self
+
+        async def __aexit__(
+            self,
+            exc_type: Optional[Type[BaseException]],
+            exc_val: Optional[BaseException],
+            exc_tb: Optional[TracebackType],
+        ) -> Optional[bool]:
+            self._do_exit(exc_type)
+            return None
+
+        @property
+        def expired(self) -> bool:
+            """Is timeout expired during execution?"""
+            return self._state == _State.TIMEOUT
+
+        @property
+        def deadline(self) -> Optional[float]:
+            return self._deadline
+
+        def reject(self) -> None:
+            """Reject scheduled timeout if any."""
+            # cancel is maybe better name but
+            # task.cancel() raises CancelledError in asyncio world.
+            if self._state not in (_State.INIT, _State.ENTER):
+                raise RuntimeError(f"invalid state {self._state.value}")
+            self._reject()
+
+        def _reject(self) -> None:
+            if self._timeout_handler is not None:
+                self._timeout_handler.cancel()
+                self._timeout_handler = None
+
+        def shift(self, delay: float) -> None:
+            """Advance timeout on delay seconds.
+            The delay can be negative.
+            Raise RuntimeError if shift is called when deadline is not scheduled
+            """
+            deadline = self._deadline
+            if deadline is None:
+                raise RuntimeError(
+                    "cannot shift timeout if deadline is not scheduled")
+            self.update(deadline + delay)
+
+        def update(self, deadline: float) -> None:
+            """Set deadline to absolute value.
+            deadline argument points on the time in the same clock system
+            as loop.time().
+            If new deadline is in the past the timeout is raised immediately.
+            Please note: it is not POSIX time but a time with
+            undefined starting base, e.g. the time of the system power on.
+ """ + if self._state == _State.EXIT: + raise RuntimeError( + "cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon( + self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at( + deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and \ + self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: + if task: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None