@@ -11,13 +11,14 @@
 
 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:
 
     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
 
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)