Commit 639712b

Allow dynamic loading of LoRA adapters in a cache dir
Signed-off-by: jberkhahn <[email protected]>
1 parent a79cc68 commit 639712b

15 files changed, +223 -18 lines changed

docs/source/features/lora.md

Lines changed: 7 additions & 0 deletions
@@ -153,6 +153,13 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
 }'
 ```
 
+## Dynamically load LoRA adapters from a directory
+
+vLLM also supports setting a local directory to check for cached LoRA adapters. With dynamic LoRA loading enabled, set
+`--lora-cache-dir {path}`. When vLLM receives a request for a LoRA adapter `foobar` that it doesn't recognize, it
+will check `{path}/foobar` for that adapter, load it if possible, and then serve the request using that adapter.
+Thereafter, the adapter remains available to incoming requests as normal.
+
 ## New format for `--lora-modules`
 
 In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
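
The added section above describes the runtime flow end to end. Below is a minimal client-side sketch of that flow, assuming a server started with `VLLM_ALLOW_RUNTIME_LORA_UPDATING=True` and `--lora-cache-dir /path/to/adapter_cache`, and assuming `/path/to/adapter_cache/my-lora` contains a LoRA adapter for the base model; the paths and the adapter name `my-lora` are illustrative, not part of this commit.

```python
# Client-side sketch (names and paths are illustrative).
# Assumes the server was launched with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True and
# --lora-cache-dir /path/to/adapter_cache, and that /path/to/adapter_cache/my-lora
# holds a LoRA adapter for the base model.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# "my-lora" is not pre-registered; the server resolves it against the cache dir,
# loads it on first use, and serves the request with that adapter.
completion = client.completions.create(
    model="my-lora",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)

# After the first request, the adapter shows up in the model list like any other.
print([m.id for m in client.models.list().data])
```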
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+
+@pytest.fixture(scope='module')
+def adapter_cache(request, tmpdir_factory):
+    # Create dir that mimics the structure of the adapter cache
+    adapter_cache = tmpdir_factory.mktemp(
+        request.module.__name__) / "adapter_cache"
+    return adapter_cache
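
The fixture above only creates the empty cache directory; the tests that follow populate it by copying an adapter into a subdirectory named after the adapter. A small sketch of that convention (the helper name and the `my-lora` adapter name are illustrative, not part of this commit):

```python
# Sketch of how a test can seed the cache dir (mirrors the tests below).
import shutil


def seed_adapter_cache(adapter_cache, lora_src_dir, adapter_name="my-lora"):
    """Copy an adapter into <adapter_cache>/<adapter_name>, the layout the
    server scans when it sees an unknown adapter name."""
    target = adapter_cache / adapter_name
    shutil.copytree(lora_src_dir, target)
    return target
```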

tests/entrypoints/openai/test_completion.py

Lines changed: 31 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 # imports for guided decoding tests
 import json
+import random
 import re
 import shutil
 from tempfile import TemporaryDirectory
@@ -61,7 +62,7 @@ def zephyr_pa_files():
 
 @pytest.fixture(scope="module")
 def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
-                        zephyr_pa_files):
+                        zephyr_pa_files, adapter_cache):
     return [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -80,6 +81,8 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
         "64",
         "--max-cpu-loras",
         "2",
+        "--lora-cache-dir",
+        str(adapter_cache),
         # pa config
         "--enable-prompt-adapter",
         "--prompt-adapters",
@@ -97,7 +100,12 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
 def server(default_server_args, request):
     if request.param:
         default_server_args.append(request.param)
-    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
+
+    lora_env = {
+        "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
+    }
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args,
+                            env_dict=lora_env) as remote_server:
         yield remote_server
 
 
@@ -144,6 +152,27 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
     assert completion.choices[0].prompt_logprobs is None
 
 
+@pytest.mark.asyncio
+async def test_cached_lora_completion(client: openai.AsyncOpenAI,
+                                      adapter_cache, zephyr_lora_files):
+    cached_lora_name = f"zephyr-7b-beta-lora-{random.random()}"
+    model_files = adapter_cache / cached_lora_name
+    shutil.copytree(zephyr_lora_files, model_files)
+
+    completion = await client.completions.create(model=MODEL_NAME,
+                                                 prompt="Hello, my name is",
+                                                 max_tokens=5,
+                                                 temperature=0.0)
+    assert completion.choices[0].text == " Sarah and I am a"
+
+    lora_completion = await client.completions.create(
+        model=cached_lora_name,
+        prompt="Hello, my name is",
+        max_tokens=5,
+        temperature=0.0)
+    assert completion.choices[0].text != lora_completion.choices[0].text
+
+
 @pytest.mark.asyncio
 async def test_added_lora_tokens(client: openai.AsyncOpenAI):
     # test using token IDs

tests/entrypoints/openai/test_lora_adapters.py

Lines changed: 76 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 import asyncio
 import json
+import random
 import shutil
 from contextlib import suppress
 
@@ -52,6 +53,11 @@ def zephyr_lora_files():
     return snapshot_download(repo_id=LORA_NAME)
 
 
+@pytest.fixture(scope="module")
+def zephyr_nonlora_files():
+    return snapshot_download(repo_id=MODEL_NAME)
+
+
 @pytest.fixture(scope="module")
 def monkeypatch_module():
     from _pytest.monkeypatch import MonkeyPatch
@@ -62,7 +68,7 @@ def monkeypatch_module():
 
 @pytest.fixture(scope="module", params=[False, True])
 def server_with_lora_modules_json(request, monkeypatch_module,
-                                  zephyr_lora_files):
+                                  zephyr_lora_files, adapter_cache):
 
     use_v1 = request.param
     monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
@@ -98,12 +104,16 @@ def server_with_lora_modules_json(request, monkeypatch_module,
         "2",
         "--max-num-seqs",
         "64",
+        "--lora-cache-dir",
+        str(adapter_cache),
     ]
 
-    # Enable the /v1/load_lora_adapter endpoint
-    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
+    lora_env = {
+        "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
+    }
 
-    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args,
+                            env_dict=lora_env) as remote_server:
         yield remote_server
 
 
@@ -131,6 +141,68 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI,
     assert lora_models[1].id == "zephyr-lora2"
 
 
+@pytest.mark.asyncio
+async def test_cached_non_lora_adapter(client: openai.AsyncOpenAI, tmp_path,
+                                       zephyr_nonlora_files, adapter_cache):
+    """Validate that a cached model that isn't a lora adapter will not be
+    loaded from the cache directory"""
+    cached_nonlora_name = f"zephyr-7b-beta-{random.random()}"
+    model_files = adapter_cache / cached_nonlora_name
+    shutil.copytree(zephyr_nonlora_files, model_files)
+
+    models = await client.models.list()
+    models = models.data
+    lora_models = models[1:]
+    assert len(lora_models) == 2
+
+    with pytest.raises(openai.NotFoundError):
+        await client.completions.create(
+            model=cached_nonlora_name,
+            prompt=["Hello there", "Foo bar bazz buzz"],
+            max_tokens=5,
+        )
+
+    models = await client.models.list()
+    models = models.data
+    lora_models = models[1:]
+    assert len(lora_models) == 2
+    assert lora_models[0].id != cached_nonlora_name
+    assert lora_models[1].id != cached_nonlora_name
+
+
+@pytest.mark.asyncio
+async def test_cached_lora_adapter(client: openai.AsyncOpenAI, tmp_path,
+                                   zephyr_lora_files, adapter_cache):
+    """Validate that a lora adapter can be dynamically discovered and loaded
+    from the cache directory"""
+    cached_lora_name = f"zephyr-7b-beta-lora-{random.random()}"
+    model_files = adapter_cache / cached_lora_name
+    shutil.copytree(zephyr_lora_files, model_files)
+
+    models = await client.models.list()
+    models = models.data
+    lora_models = models[1:]
+    assert len(lora_models) == 2
+    assert lora_models[0].id != cached_lora_name
+    assert lora_models[1].id != cached_lora_name
+
+    result = await client.completions.create(
+        model=cached_lora_name,
+        prompt=["Hello there", "Foo bar bazz buzz"],
+        max_tokens=5,
+    )
+
+    assert not isinstance(result, Exception), f"Got exception {result}"
+    assert isinstance(result, openai.types.Completion)
+    assert result.model == cached_lora_name
+
+    models = await client.models.list()
+    models = models.data
+    lora_models = models[1:]
+    assert len(lora_models) == 3
+    assert lora_models[2].id == cached_lora_name
+
+
 @pytest.mark.asyncio
 async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
                                     zephyr_lora_files):

vllm/entrypoints/openai/api_server.py

Lines changed: 16 additions & 1 deletion
@@ -42,6 +42,7 @@
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_lora_cache_args,
                                               validate_parsed_serve_args)
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -951,6 +952,7 @@ async def init_app_state(
         args.response_role,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
+        lora_cache_dir=args.lora_cache_dir,
         chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
@@ -963,6 +965,7 @@ async def init_app_state(
         engine_client,
         model_config,
         state.openai_serving_models,
+        lora_cache_dir=args.lora_cache_dir,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     ) if model_config.runner_type == "generate" else None
@@ -972,6 +975,7 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
+        lora_cache_dir=args.lora_cache_dir,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.runner_type == "pooling" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
@@ -980,24 +984,28 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
+        lora_cache_dir=args.lora_cache_dir,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
     state.openai_serving_scores = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
+        lora_cache_dir=args.lora_cache_dir,
         request_logger=request_logger) if model_config.task in (
             "score", "embed", "pooling") else None
     state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
+        lora_cache_dir=args.lora_cache_dir,
         request_logger=request_logger
     ) if model_config.task == "score" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
         state.openai_serving_models,
+        lora_cache_dir=args.lora_cache_dir,
         request_logger=request_logger,
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
@@ -1006,6 +1014,7 @@ async def init_app_state(
         engine_client,
         model_config,
         state.openai_serving_models,
+        lora_cache_dir=args.lora_cache_dir,
         request_logger=request_logger,
     ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task
@@ -1067,7 +1076,12 @@ def signal_handler(*_) -> None:
     app = build_app(args)
 
     model_config = await engine_client.get_model_config()
-    await init_app_state(engine_client, model_config, app.state, args)
+    await init_app_state(
+        engine_client,
+        model_config,
+        app.state,
+        args,
+    )
 
     def _listen_addr(a: str) -> str:
         if is_valid_ipv6_address(a):
@@ -1113,5 +1127,6 @@ def _listen_addr(a: str) -> str:
     parser = make_arg_parser(parser)
     args = parser.parse_args()
     validate_parsed_serve_args(args)
+    validate_lora_cache_args(args)
 
     uvloop.run(run_server(args))

vllm/entrypoints/openai/cli_args.py

Lines changed: 17 additions & 0 deletions
@@ -11,6 +11,7 @@
 from collections.abc import Sequence
 from typing import Optional, Union, get_args
 
+import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
@@ -124,6 +125,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "Example (new format): "
         "``{\"name\": \"name\", \"path\": \"lora_path\", "
         "\"base_model_name\": \"id\"}``")
+    parser.add_argument(
+        '--lora-cache-dir',
+        type=nullable_str,
+        default=None,
+        help=('Directory to look for LoRA adapters if an unknown adapter '
+              'is specified in a request. Requires '
+              'VLLM_ALLOW_RUNTIME_LORA_UPDATING to be enabled.'))
     parser.add_argument(
         "--prompt-adapters",
         type=nullable_str,
@@ -290,6 +298,15 @@ def validate_parsed_serve_args(args: argparse.Namespace):
                         "--reasoning-parser")
 
 
+def validate_lora_cache_args(args: argparse.Namespace):
+    """Check that dynamic lora is enabled if the lora cache dir is set"""
+    if args.lora_cache_dir is not None and \
+            not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+        raise ValueError(
+            f"lora_cache_dir ({args.lora_cache_dir}) cannot be set if "
+            "VLLM_ALLOW_RUNTIME_LORA_UPDATING is not enabled")
+
+
 def create_parser_for_docs() -> FlexibleArgumentParser:
     parser_for_docs = FlexibleArgumentParser(
         prog="-m vllm.entrypoints.openai.api_server")

vllm/entrypoints/openai/run_batch.py

Lines changed: 3 additions & 0 deletions
@@ -330,6 +330,7 @@ async def main(args):
         model_config,
         openai_serving_models,
         args.response_role,
+        lora_cache_dir=None,
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
@@ -339,6 +340,7 @@ async def main(args):
         engine,
         model_config,
         openai_serving_models,
+        lora_cache_dir=None,
         request_logger=request_logger,
         chat_template=None,
         chat_template_content_format="auto",
@@ -347,6 +349,7 @@ async def main(args):
         engine,
         model_config,
         openai_serving_models,
+        lora_cache_dir=None,
         request_logger=request_logger,
     ) if model_config.task == "score" else None)

vllm/entrypoints/openai/serving_chat.py

Lines changed: 3 additions & 1 deletion
@@ -59,10 +59,12 @@ def __init__(
         enable_auto_tools: bool = False,
         tool_parser: Optional[str] = None,
         enable_prompt_tokens_details: bool = False,
+        lora_cache_dir: Optional[str] = None,
     ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          models=models,
+                         lora_cache_dir=lora_cache_dir,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
 
@@ -142,7 +144,7 @@ async def create_chat_completion(
             (
                 lora_request,
                 prompt_adapter_request,
-            ) = self._maybe_get_adapters(request)
+            ) = await self._maybe_get_adapters(request)
 
             model_name = self._get_model_name(request.model, lora_request)
 
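
The `await` added to `_maybe_get_adapters` in `serving_chat.py` reflects that resolving an adapter name can now involve loading it from the cache directory, which is asynchronous. The actual resolution logic lives in the shared serving layer, which is not shown in this excerpt; the sketch below is only a hypothetical illustration of the idea, and every name in it is invented for illustration.

```python
# Hypothetical illustration only -- NOT the code added by this commit.
# It sketches why adapter resolution is awaited: an unknown adapter name may
# first have to be loaded from the cache directory (an async operation).
import os
from typing import Any, Awaitable, Callable, Optional


class CacheAwareAdapterResolver:
    """Resolve an adapter name, falling back to a local cache directory."""

    def __init__(self, known_adapters: dict[str, Any],
                 lora_cache_dir: Optional[str],
                 load_adapter: Callable[[str, str], Awaitable[Any]]):
        self.known_adapters = known_adapters  # already-registered adapters
        self.lora_cache_dir = lora_cache_dir  # value of --lora-cache-dir
        self.load_adapter = load_adapter      # async dynamic-load routine

    async def maybe_get_adapter(self, model_name: str) -> Optional[Any]:
        if model_name in self.known_adapters:
            return self.known_adapters[model_name]
        if self.lora_cache_dir is None:
            return None
        candidate = os.path.join(self.lora_cache_dir, model_name)
        if os.path.isdir(candidate):
            # Loading is I/O-bound, hence the caller must await.
            return await self.load_adapter(model_name, candidate)
        return None
```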