Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/features/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
}'
```

## Dynamically load LoRA Adapters from a directory

vLLM also supports setting a local directory to check for cached LoRA adapters. While dynamic LoRA is enabled, set
`--lora-cache-dir {path}`. When vLLM receives a request for a LoRA adapter `foobar` that it doesn't recognize, it
will check `path/foobar` for that adapter, load the adapter if able, and then service the request using that adapter.
Thereafter the adapter will be available for use by incoming requests as normal.

## New format for `--lora-modules`

In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
Expand Down
11 changes: 11 additions & 0 deletions tests/entrypoints/openai/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import pytest


@pytest.fixture(scope='module')
def adapter_cache(request, tmpdir_factory):
# Create dir that mimics the structure of the adapter cache
adapter_cache = tmpdir_factory.mktemp(
request.module.__name__) / "adapter_cache"
return adapter_cache
33 changes: 31 additions & 2 deletions tests/entrypoints/openai/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

# imports for guided decoding tests
import json
import random
import re
import shutil
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -61,7 +62,7 @@ def zephyr_pa_files():

@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
zephyr_pa_files):
zephyr_pa_files, adapter_cache):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
Expand All @@ -80,6 +81,8 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
"64",
"--max-cpu-loras",
"2",
"--lora-cache-dir",
str(adapter_cache),
# pa config
"--enable-prompt-adapter",
"--prompt-adapters",
Expand All @@ -97,7 +100,12 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
def server(default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:

lora_env = {
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
}
with RemoteOpenAIServer(MODEL_NAME, default_server_args,
env_dict=lora_env) as remote_server:
yield remote_server


Expand Down Expand Up @@ -144,6 +152,27 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
assert completion.choices[0].prompt_logprobs is None


@pytest.mark.asyncio
async def test_cached_lora_completion(client: openai.AsyncOpenAI,
adapter_cache, zephyr_lora_files):
cached_lora_name = f"zephyr-7b-beta-lora-{random.random()}"
model_files = adapter_cache / cached_lora_name
shutil.copytree(zephyr_lora_files, model_files)

completion = await client.completions.create(model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.choices[0].text == " Sarah and I am a"

lora_completion = await client.completions.create(
model=cached_lora_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
assert completion.choices[0].text != lora_completion.choices[0].text


@pytest.mark.asyncio
async def test_added_lora_tokens(client: openai.AsyncOpenAI):
# test using token IDs
Expand Down
80 changes: 76 additions & 4 deletions tests/entrypoints/openai/test_lora_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import json
import random
import shutil
from contextlib import suppress

Expand Down Expand Up @@ -52,6 +53,11 @@ def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def zephyr_nonlora_files():
return snapshot_download(repo_id=MODEL_NAME)


@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
Expand All @@ -62,7 +68,7 @@ def monkeypatch_module():

@pytest.fixture(scope="module", params=[False, True])
def server_with_lora_modules_json(request, monkeypatch_module,
zephyr_lora_files):
zephyr_lora_files, adapter_cache):

use_v1 = request.param
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
Expand Down Expand Up @@ -98,12 +104,16 @@ def server_with_lora_modules_json(request, monkeypatch_module,
"2",
"--max-num-seqs",
"64",
"--lora-cache-dir",
str(adapter_cache),
]

# Enable the /v1/load_lora_adapter endpoint
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
lora_env = {
"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True",
}

with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
with RemoteOpenAIServer(MODEL_NAME, args,
env_dict=lora_env) as remote_server:
yield remote_server


Expand Down Expand Up @@ -131,6 +141,68 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI,
assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
async def test_cached_non_lora_adapter(client: openai.AsyncOpenAI, tmp_path,
zephyr_nonlora_files, adapter_cache):
"""Validate that a cached model that isn't a lora adapter will not be
loaded from the cache directory"""
cached_nonlora_name = f"zephyr-7b-beta-{random.random()}"
model_files = adapter_cache / cached_nonlora_name
shutil.copytree(zephyr_nonlora_files, model_files)

models = await client.models.list()
models = models.data
lora_models = models[1:]
assert len(lora_models) == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this ==2 ? are we also counting the adapter added in a previous tests ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes the server is shared between all the tests in the module


with pytest.raises(openai.NotFoundError):
await client.completions.create(
model=cached_nonlora_name,
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)

models = await client.models.list()
models = models.data
lora_models = models[1:]
assert len(lora_models) == 2
assert lora_models[0].id != cached_nonlora_name
assert lora_models[1].id != cached_nonlora_name


@pytest.mark.asyncio
async def test_cached_lora_adapter(client: openai.AsyncOpenAI, tmp_path,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice test 👍 !

I think a negative test case would be good too. Can we throw some junk in the adapter cache, or maybe an adapter for a different base model, and ensure that we get a graceful 400 when trying to run a chat completions with it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i can add a negative test case for something that isn't a lora adapter, but how does a request know if it has a matching base model? the request only has a model (which is actually a lora adapter) set, i dont see anything that related the request to a base model, except that the lora adapter itself has its own base model.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you try to load a lora adapter for a different model architecture than the base model, or a differently sized model than the base model, then the engine should fail to load it and return an error response which it looks like your code would already catch

zephyr_lora_files, adapter_cache):
"""Validate that a lora adapter can be dynamically discovered and loaded
from the cache directory"""
cached_lora_name = f"zephyr-7b-beta-lora-{random.random()}"
model_files = adapter_cache / cached_lora_name
shutil.copytree(zephyr_lora_files, model_files)

models = await client.models.list()
models = models.data
lora_models = models[1:]
assert len(lora_models) == 2
assert lora_models[0].id != cached_lora_name
assert lora_models[1].id != cached_lora_name

result = await client.completions.create(
model=cached_lora_name,
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)

assert not isinstance(result, Exception), f"Got exception {result}"
assert isinstance(result, openai.types.Completion)
assert result.model == cached_lora_name

models = await client.models.list()
models = models.data
lora_models = models[1:]
assert len(lora_models) == 3
assert lora_models[2].id == cached_lora_name


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
zephyr_lora_files):
Expand Down
17 changes: 16 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_lora_cache_args,
validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
Expand Down Expand Up @@ -951,6 +952,7 @@ async def init_app_state(
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
lora_cache_dir=args.lora_cache_dir,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could local_cache_dir be part of OpenAIServingModels ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this is probably a good idea, but it looks like i broke some tests when i just rebased on master. I'll get to this once i'm done squashing that.

chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
Expand All @@ -963,6 +965,7 @@ async def init_app_state(
engine_client,
model_config,
state.openai_serving_models,
lora_cache_dir=args.lora_cache_dir,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
Expand All @@ -972,6 +975,7 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
lora_cache_dir=args.lora_cache_dir,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
Expand All @@ -980,24 +984,28 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
lora_cache_dir=args.lora_cache_dir,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
lora_cache_dir=args.lora_cache_dir,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
lora_cache_dir=args.lora_cache_dir,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
state.openai_serving_models,
lora_cache_dir=args.lora_cache_dir,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
Expand All @@ -1006,6 +1014,7 @@ async def init_app_state(
engine_client,
model_config,
state.openai_serving_models,
lora_cache_dir=args.lora_cache_dir,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
state.task = model_config.task
Expand Down Expand Up @@ -1067,7 +1076,12 @@ def signal_handler(*_) -> None:
app = build_app(args)

model_config = await engine_client.get_model_config()
await init_app_state(engine_client, model_config, app.state, args)
await init_app_state(
engine_client,
model_config,
app.state,
args,
)

def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
Expand Down Expand Up @@ -1113,5 +1127,6 @@ def _listen_addr(a: str) -> str:
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
validate_lora_cache_args(args)

uvloop.run(run_server(args))
17 changes: 17 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections.abc import Sequence
from typing import Optional, Union, get_args

import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
Expand Down Expand Up @@ -124,6 +125,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Example (new format): "
"``{\"name\": \"name\", \"path\": \"lora_path\", "
"\"base_model_name\": \"id\"}``")
parser.add_argument(
'--lora-cache-dir',
type=nullable_str,
default=None,
help=('Directory to look for LoRA adapters if an unknown adapter '
'is specified in a request. Requires '
'VLLM_ALLOW_RUNTIME_LORA_UPDATING to be enabled.'))
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
Expand Down Expand Up @@ -290,6 +298,15 @@ def validate_parsed_serve_args(args: argparse.Namespace):
"--reasoning-parser")


def validate_lora_cache_args(args: argparse.Namespace):
"""Check that dynamic lora is enabled if the lora cache dir is set"""
if args.lora_cache_dir is not None and \
not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
raise ValueError(
f"lora_cache_dir ({args.lora_cache_dir}) cannot be set if "
"VLLM_ALLOW_RUNTIME_LORA_UPDATING is not enabled")


def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
Expand Down
3 changes: 3 additions & 0 deletions vllm/entrypoints/openai/run_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ async def main(args):
model_config,
openai_serving_models,
args.response_role,
lora_cache_dir=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
Expand All @@ -339,6 +340,7 @@ async def main(args):
engine,
model_config,
openai_serving_models,
lora_cache_dir=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
Expand All @@ -347,6 +349,7 @@ async def main(args):
engine,
model_config,
openai_serving_models,
lora_cache_dir=None,
request_logger=request_logger,
) if model_config.task == "score" else None)

Expand Down
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ def __init__(
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
lora_cache_dir: Optional[str] = None,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
lora_cache_dir=lora_cache_dir,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)

Expand Down Expand Up @@ -142,7 +144,7 @@ async def create_chat_completion(
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
) = await self._maybe_get_adapters(request)

model_name = self._get_model_name(request.model, lora_request)

Expand Down
Loading