launch_scripts/launch_vllm (12 changes: 6 additions & 6 deletions)
@@ -15,7 +15,7 @@ from types import SimpleNamespace
 import subprocess
 import threading
 from dataclasses import dataclass
-from vllm import ServerArgs, LLMServer, SamplingParams
+from vllm import EngineArgs, LLMEngine, SamplingParams
 import os
 os.environ["TRANSFORMERS_CACHE"] = '/data/cache'

@@ -88,7 +88,7 @@ class ModelThread:

         needs_call_progress = False
         for vllm_output in vllm_outputs:
-            if not vllm_output.finished():
+            if not vllm_output.finished:
                 continue

             needs_call_progress = True
@@ -111,8 +111,8 @@ class ModelThread:
     @staticmethod
     def init_model(vllm_args):
         print('Init model')
-        server_args = ServerArgs.from_cli_args(vllm_args)
-        server = LLMServer.from_server_args(server_args)
+        server_args = EngineArgs.from_cli_args(vllm_args)
+        server = LLMEngine.from_engine_args(server_args)
         print('Model ready')
         return server

@@ -225,10 +225,10 @@ async def is_ready(request: Request):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--port', type=int, required=True)
-    ServerArgs.add_cli_args(parser)
+    EngineArgs.add_cli_args(parser)
     args = parser.parse_args()

-    vllm_args = ServerArgs.from_cli_args(args)
+    vllm_args = EngineArgs.from_cli_args(args)

     loop = asyncio.new_event_loop()
     server = FastAPIServer(loop, vllm_args)
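Taken together, the launch_vllm changes track vLLM's rename from the server-style entry points (ServerArgs, LLMServer) to the engine-style ones (EngineArgs, LLMEngine), plus the switch from a finished() call to a finished attribute on request outputs. Below is a minimal standalone sketch of the same pattern, assuming a vLLM version that exposes EngineArgs, LLMEngine, and a boolean RequestOutput.finished; the model name, prompt, and request id are illustrative and not taken from this PR.

import argparse

from vllm import EngineArgs, LLMEngine, SamplingParams

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)                       # was ServerArgs.add_cli_args
args = parser.parse_args(["--model", "facebook/opt-125m"])   # illustrative model

engine_args = EngineArgs.from_cli_args(args)          # was ServerArgs.from_cli_args
engine = LLMEngine.from_engine_args(engine_args)      # was LLMServer.from_server_args

engine.add_request("req-0", "Hello, my name is", SamplingParams(max_tokens=16))

while engine.has_unfinished_requests():
    for output in engine.step():                      # step() returns RequestOutput objects
        if output.finished:                           # attribute now, not a method call
            print(output.outputs[0].text)

The PR's launch script wraps these same calls in a FastAPIServer and a ModelThread; the sketch only isolates the renamed entry points.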
launch_scripts/launch_vllm_ray_serve (10 changes: 6 additions & 4 deletions)
@@ -51,7 +51,7 @@ class VLLMPredictDeployment:
                 yield (json.dumps(ret) + "\0").encode("utf-8")

         async def abort_request() -> None:
-            await engine.abort(request_id)
+            await self.engine.abort(request_id)

         if stream:
             background_tasks = BackgroundTasks()
@@ -64,7 +64,7 @@ class VLLMPredictDeployment:
         async for request_output in results_generator:
             if await request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await engine.abort(request_id)
+                await self.engine.abort(request_id)
                 return Response(status_code=499)
             final_output = request_output

@@ -93,8 +93,10 @@ if __name__ == "__main__":
     parser.add_argument('--port', type=int, required=True)
     args = parser.parse_args()

-    model = 'facebook/opt-13b'
+    model = 'facebook/opt-125m' #'facebook/opt-13b'
     deployment = VLLMPredictDeployment.bind(
-        model=model, max_num_batched_tokens=8100, use_np_weights=True)
+        model=model, max_num_batched_tokens=8100)
     serve.run(deployment, port=args.port)
     send_request()
+    while True:
+        pass
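In launch_vllm_ray_serve the functional fix is scoping: inside the deployment's request handler there is no bare engine name, so aborts have to go through the engine held on the instance as self.engine. The remaining changes swap in the small facebook/opt-125m model, drop the use_np_weights flag from the bind call, and keep the process alive after serve.run. A rough sketch of the deployment shape this implies, assuming an AsyncLLMEngine stored on self in __init__ and the bind kwargs forwarded to AsyncEngineArgs; everything beyond the identifiers visible in the diff is illustrative.

import uuid

from fastapi import Request
from fastapi.responses import Response
from ray import serve
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


@serve.deployment
class VLLMPredictDeployment:
    def __init__(self, **kwargs):
        # kwargs such as model=... and max_num_batched_tokens=... come from
        # VLLMPredictDeployment.bind(...) and are forwarded to the engine args.
        self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**kwargs))

    async def __call__(self, request: Request) -> Response:
        body = await request.json()
        request_id = str(uuid.uuid4())
        results_generator = self.engine.generate(
            body["prompt"], SamplingParams(), request_id)

        final_output = None
        async for request_output in results_generator:
            if await request.is_disconnected():
                # The fix from the diff: abort via the engine on the instance,
                # not a free-standing `engine` variable.
                await self.engine.abort(request_id)
                return Response(status_code=499)
            final_output = request_output
        return Response(content=final_output.outputs[0].text)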