
Commit e13e76e

randyjhccomaniac authored and committed
[Feature] Add vllm bench CLI (vllm-project#13993)
Signed-off-by: Randy Chen <[email protected]>
Signed-off-by: Cody Yu <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
Signed-off-by: Richard Liu <[email protected]>
1 parent de167bc commit e13e76e
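
As a rough sketch of what the new entry point enables (the exact subcommands and flags are not visible in this diff, so treat the names below as assumptions based on the pre-existing benchmark scripts):

# Hypothetical invocation; subcommand, model, and flag names are assumed.
vllm bench serve --model meta-llama/Llama-2-7b-hf --num-prompts 100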

File tree

8 files changed: +1274 -0 lines changed

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
"""The request function for API endpoints."""

import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """The input for the request function."""
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: Optional[str] = None
    best_of: int = 1
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False


@dataclass
class RequestFuncOutput:
    """The output of the request function including metrics."""
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(
        default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    error: str = ""

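# The request helper below consumes the response as server-sent events:
# each non-empty chunk is a "data: <json>" line, and the stream
# terminates with "data: [DONE]".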
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Args:
        request_func_input: The input for the request function.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name else request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion APIs might send a last
                            # usage summary response without a token, so we
                            # want to check that a token was generated.
                            if choices := data.get("choices"):
                                # Note that text could be empty here,
                                # e.g. for special tokens.
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
    "openai-comp": async_request_openai_completions,
}
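
For context, a minimal driver for this module might look like the sketch below. The import path, server URL, model name, and token counts are illustrative assumptions; the diff does not show where the file lives or how the new `vllm bench` command invokes it.

import asyncio

# Assumes the file above is importable as `request_func`; the actual
# module path is not visible in this diff.
from request_func import ASYNC_REQUEST_FUNCS, RequestFuncInput


async def main() -> None:
    # Look up the request function by its backend key.
    request_func = ASYNC_REQUEST_FUNCS["openai-comp"]
    out = await request_func(
        RequestFuncInput(
            prompt="Hello, world!",
            api_url="http://localhost:8000/v1/completions",  # assumed server
            prompt_len=4,  # illustrative token counts
            output_len=64,
            model="my-model",  # placeholder model name
        ))
    if out.success:
        print(f"ttft={out.ttft:.3f}s latency={out.latency:.3f}s "
              f"tokens={out.output_tokens}")
    else:
        print(f"request failed: {out.error}")


asyncio.run(main())

Note that `tpot` is declared on `RequestFuncOutput` but never set in this module; presumably the benchmark harness derives it downstream, e.g. as `(latency - ttft) / (output_tokens - 1)` when more than one token is generated.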

0 commit comments