Commit 029c71d

[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer (#7836)
1 parent 0b76999 commit 029c71d

2 files changed: +27 / -15 lines
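
In short: the test helper previously ran huggingface_hub.snapshot_download(model) before starting the server, which pulls every file in the model repository. It now builds the engine config from the parsed CLI args and reuses DefaultModelLoader._prepare_weights, so only the files the loader actually needs are fetched up front. A rough sketch of the difference in huggingface_hub terms (the allow_patterns list is an illustrative assumption, not the loader's real filter):

```python
from huggingface_hub import snapshot_download

repo = "facebook/opt-125m"  # hypothetical repo, for illustration only

# Old behaviour: download the entire repository, including files the
# test server never reads.
snapshot_download(repo)

# Roughly what the new DefaultModelLoader._prepare_weights path achieves:
# fetch only weights and config. These patterns are an assumption for
# illustration; the real filtering lives inside the loader.
snapshot_download(repo, allow_patterns=["*.safetensors", "*.bin", "*.json"])
```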

tests/utils.py

Lines changed: 26 additions & 14 deletions
```diff
@@ -11,13 +11,14 @@
 
 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:
 
     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
 
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)
```
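
For reference, a minimal sketch of how a test might launch the server with the renamed vllm_serve_args parameter; the context-manager usage and the get_client() helper are assumptions about the surrounding test utilities, not part of this diff:

```python
# Sketch only: context-manager support and `get_client()` are assumed
# helpers of RemoteOpenAIServer; they are not shown in this diff.
from tests.utils import RemoteOpenAIServer

MODEL = "facebook/opt-125m"  # hypothetical small model for illustration

with RemoteOpenAIServer(MODEL, ["--max-model-len", "1024"]) as server:
    client = server.get_client()  # assumed OpenAI-compatible client wrapper
    completion = client.completions.create(model=MODEL,
                                           prompt="Hello, my name is",
                                           max_tokens=5)
    print(completion.choices[0].text)
```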

vllm/engine/arg_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -742,7 +742,7 @@ def from_cli_args(cls, args: argparse.Namespace):
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_config(self, ) -> EngineConfig:
+    def create_engine_config(self) -> EngineConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if self.model.endswith(".gguf"):
             self.quantization = self.load_format = "gguf"
```
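
The create_engine_config signature cleanup looks incidental: the test helper above now calls AsyncEngineArgs.create_engine_config() directly while pre-downloading weights, and the stray trailing comma in the signature was removed along the way.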
