[V1] Multiprocessing Tensor Parallel Support for v1 #9856
tests/basic_correctness/test_basic_correctness.py

```diff
@@ -26,6 +26,14 @@
 TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_vllm_gc_ed():
     """Verify vllm instance is GC'ed when it is deleted"""
     llm = LLM("facebook/opt-125m")
```
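For context, `run_with_both_engines` is a fixture defined in the test suite's conftest. A minimal sketch of how such a fixture could be built; the parametrization and env-var approach here are illustrative assumptions, not vLLM's actual implementation:

```python
# Hypothetical sketch of a run_with_both_engines-style fixture: it
# parametrizes each test to run once per engine by patching the
# VLLM_USE_V1 environment variable.
import pytest


@pytest.fixture(params=[False, True], ids=["v0", "v1"])
def run_with_both_engines(request, monkeypatch):
    use_v1 = request.param
    # Patch the env var so any code that reads it after this point
    # sees the selected engine; each test runs twice, once per param.
    monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
    yield
```

Making the `v1` wrapper autouse in a module opts every test in that module in without editing the individual tests, which is why the comment notes it could be promoted to a package-level conftest.py.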
```diff
@@ -36,6 +44,7 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+@pytest.mark.skip_v1
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
```
```diff
@@ -118,6 +127,11 @@ def test_models_distributed(
     if attention_backend:
         os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend
 
+    # Import VLLM_USE_V1 dynamically to handle patching
+    from vllm.envs import VLLM_USE_V1
+    if VLLM_USE_V1 and distributed_executor_backend != "mp":
+        pytest.skip(f"Skip {distributed_executor_backend} for V1")
+
     dtype = "half"
     max_tokens = 5
```
```diff
@@ -143,6 +157,7 @@ def test_models_distributed(
     )
 
 
+@pytest.mark.skip_v1
 def test_model_with_failure(vllm_runner) -> None:
     try:
         with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
```

Review comment (on the added `@pytest.mark.skip_v1`): Why is this skipped?

Reply: This test fails on V1 but I don't know why. It's not related to this PR, as it's not running TP and fails on current main (just enabled it in #10864).
```diff
@@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
     os.remove(filename)
 
 
+@pytest.mark.skip_v1
 def test_failure_with_async_out_proc(vllm_runner) -> None:
 
     filename = None
```
vllm/distributed/device_communicators/shm_broadcast.py

```diff
@@ -1,10 +1,11 @@
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from unittest.mock import patch
 
 import torch
```
```diff
@@ -21,6 +22,20 @@
 
 logger = init_logger(__name__)
 
+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+
+
 class ShmRingBuffer:
```
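The ~3e-7 s figure can be sanity-checked with a quick micro-benchmark; a rough sketch (results vary by platform, and `os.sched_yield` is POSIX-only):

```python
# Rough micro-benchmark comparing the two yield strategies used above.
# On recent CPython (>= 3.10.8 / 3.11.1), os.sched_yield() releases
# the GIL and typically turns around in well under a microsecond.
import os
import time


def time_yield(fn, iters=100_000):
    start = time.monotonic()
    for _ in range(iters):
        fn()
    return (time.monotonic() - start) / iters


print("os.sched_yield:", time_yield(os.sched_yield))
print("time.sleep(0): ", time_yield(lambda: time.sleep(0)))
```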
```diff
@@ -114,11 +129,14 @@ def __init__(self,
         # and we should suppress the error
         pass
 
+    def handle(self):
+        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+                self.shared_memory.name)
+
     def __reduce__(self):
         return (
             self.__class__,
-            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
-             self.shared_memory.name),
+            self.handle(),
         )
 
     def __del__(self):
```
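Factoring the tuple out into `handle()` gives `__reduce__` and the `Handle` dataclass one shared serialization path: pickling a `ShmRingBuffer` transfers only the four-field descriptor, and the receiving process re-attaches to the same shared memory by name. A simplified sketch of the round trip (illustrative, not the full class):

```python
# Illustrative sketch of the __reduce__/handle() pattern: pickle
# carries only a small descriptor, and unpickling re-attaches to the
# existing shared memory segment by name instead of copying its data.
import pickle
from multiprocessing import shared_memory


class Ring:
    def __init__(self, size, name=None):
        self.size = size
        # create=True on first construction; attach by name afterwards
        self.shm = shared_memory.SharedMemory(
            name=name, create=(name is None), size=size)

    def handle(self):
        return (self.size, self.shm.name)

    def __reduce__(self):
        # Unpickling calls Ring(*self.handle()), re-attaching by name
        return (self.__class__, self.handle())


ring = Ring(1024)
clone = pickle.loads(pickle.dumps(ring))
assert clone.shm.name == ring.shm.name  # same underlying segment

clone.shm.close()
ring.shm.close()
ring.shm.unlink()  # release the segment when done
```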
```diff
@@ -147,7 +165,7 @@ class Handle:
     connect_ip: str
     local_reader_ranks: List[int] = field(default_factory=list)
 
-    buffer: Optional[ShmRingBuffer] = None
+    buffer_handle: Optional[Tuple[int, int, int, str]] = None
     local_subscribe_port: Optional[int] = None
     remote_subscribe_port: Optional[int] = None
```
```diff
@@ -228,7 +246,7 @@ def __init__(
         self.handle = Handle(
             connect_ip=connect_ip,
             local_reader_ranks=local_reader_ranks,
-            buffer=self.buffer,
+            buffer_handle=self.buffer.handle(),
             local_subscribe_port=local_subscribe_port,
             remote_subscribe_port=remote_subscribe_port,
         )
```
```diff
@@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
         context = Context()
 
         if rank in handle.local_reader_ranks:
-            assert handle.buffer is not None
-            self.buffer = handle.buffer
+            assert handle.buffer_handle is not None
+            self.buffer = ShmRingBuffer(*handle.buffer_handle)
             self.current_idx = 0
             self.local_reader_rank = handle.local_reader_ranks.index(rank)
             self._is_local_reader = True
```
```diff
@@ -314,7 +332,7 @@ def wait_until_ready(self):
         assert recv == b"READY"
 
     @contextmanager
-    def acquire_write(self):
+    def acquire_write(self, timeout: Optional[float] = None):
         assert self._is_writer, "Only writers can acquire write"
         start_time = time.monotonic()
         n_warning = 1
```
```diff
@@ -329,16 +347,20 @@ def acquire_write(self):
                     # we need to wait until it is read by all readers
 
                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()
 
-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1
 
+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is either
                 # (1) not written
```
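The deadline check sits inside the polling loop rather than around it, so a writer blocked on slow readers fails with `TimeoutError` instead of spinning forever. The shape of the loop, reduced to its essentials (a sketch, not the real ring-buffer logic):

```python
# Minimal sketch of the acquire-with-timeout spin loop used in
# acquire_write/acquire_read: poll a readiness condition, yield the
# CPU between polls, log periodically, and raise TimeoutError once
# the deadline passes.
import time
from contextlib import contextmanager
from typing import Callable, Optional

WARN_INTERVAL = 60.0  # stands in for VLLM_RINGBUFFER_WARNING_INTERVAL


@contextmanager
def acquire(ready: Callable[[], bool], timeout: Optional[float] = None):
    start = time.monotonic()
    n_warning = 1
    while not ready():
        time.sleep(0)  # the real code calls sched_yield() here
        elapsed = time.monotonic() - start
        if elapsed > WARN_INTERVAL * n_warning:
            print(f"still waiting after {elapsed:.0f}s")  # debug log
            n_warning += 1
        if timeout is not None and elapsed > timeout:
            raise TimeoutError
    yield  # caller reads/writes the block while it is held
```

Because the check runs on every iteration, a caller of `enqueue(obj, timeout=...)` sees the `TimeoutError` promptly rather than blocking indefinitely on a dead peer.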
```diff
@@ -365,7 +387,7 @@ def acquire_write(self):
                 break
 
     @contextmanager
-    def acquire_read(self):
+    def acquire_read(self, timeout: Optional[float] = None):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
```
```diff
@@ -383,16 +405,20 @@ def acquire_read(self):
                     # we need to wait until it is written
 
                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()
 
-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1
 
+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
```

Review comment (on the `logger.warning` to `logger.debug` change): It seems we don't log this message at all by default now?

Reply: @youkaichao any idea how to suppress this message when vLLM is idle? Right now on current main, if we're not serving any requests, we will tell users that vLLM is probably hanging every 60 seconds 😆
```diff
@@ -406,24 +432,26 @@ def acquire_read(self):
                         1) % self.buffer.max_chunks
                 break
 
-    def enqueue(self, obj):
+    def enqueue(self, obj, timeout: Optional[float] = None):
+        """ Write to message queue with optional timeout (in seconds) """
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
         if self.n_local_reader > 0:
             if len(serialized_obj) >= self.buffer.max_chunk_bytes:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 1  # overflow
                 self.local_socket.send(serialized_obj)
             else:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 0  # not overflow
                     buf[1:len(serialized_obj) + 1] = serialized_obj
         if self.n_remote_reader > 0:
             self.remote_socket.send(serialized_obj)
 
-    def dequeue(self):
+    def dequeue(self, timeout: Optional[float] = None):
+        """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read() as buf:
+            with self.acquire_read(timeout) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object
```
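With timeouts plumbed through `enqueue`/`dequeue`, a caller (for example, an executor health-checking its workers) can bound how long it blocks on a stalled peer. A hedged usage sketch, assuming a queue `mq` wired up elsewhere and a hypothetical `handle_stalled_worker()` cleanup hook:

```python
# Hypothetical usage of the new optional timeouts. `mq` is assumed to
# be a MessageQueue already constructed via create_from_handle(...),
# and handle_stalled_worker() is a placeholder for caller-side cleanup.
try:
    mq.enqueue({"cmd": "execute_model"}, timeout=10.0)
    reply = mq.dequeue(timeout=10.0)
except TimeoutError:
    # A peer failed to make progress within 10 seconds; fail fast
    # rather than spinning in the ring buffer forever.
    handle_stalled_worker()
```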
Review comment (on the `# Import VLLM_USE_V1 dynamically to handle patching` line in the test above): What does this dynamic patching mean? `envs.VLLM_USE_V1` should read the latest env var value.

Reply: I moved the import here to get the VLLM_USE_V1 check working when we are using the `run_with_both_engines` pytest fixture during testing (see vllm/tests/conftest.py, lines 112 to 126 in bf0e382). Please LMK if you have a better idea!
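The disagreement comes down to Python binding semantics: a `from`-import copies the value into the importing namespace at the moment it executes, while an attribute access like `envs.VLLM_USE_V1` is looked up fresh each time (vLLM's `envs` module resolves these lazily via a module-level `__getattr__`). Running the `from`-import inside the test function re-evaluates it on every call, so it sees the value as patched by the fixture. A minimal illustration of the general principle, using a plain module attribute rather than vLLM's lazy `envs`:

```python
# A from-import at module top binds the value once; an attribute
# lookup sees later patches (e.g. from a test fixture).
import types

cfg = types.ModuleType("cfg")
cfg.FLAG = False

FLAG = cfg.FLAG          # bound once: stays False forever


def check_stale():
    return FLAG          # does not see later changes to cfg.FLAG


def check_fresh():
    return cfg.FLAG      # fresh attribute lookup on each call


cfg.FLAG = True          # e.g. a fixture patching the module
print(check_stale())     # False -- stale copy
print(check_fresh())     # True  -- sees the patched value
```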