-
-
Notifications
You must be signed in to change notification settings - Fork 11.6k
Allow dynamic loading of LoRA adapters in a cache dir #14634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture(scope='module') | ||
| def adapter_cache(request, tmpdir_factory): | ||
| # Create dir that mimics the structure of the adapter cache | ||
| adapter_cache = tmpdir_factory.mktemp( | ||
| request.module.__name__) / "adapter_cache" | ||
| return adapter_cache |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import asyncio | ||
| import json | ||
| import random | ||
| import shutil | ||
| from contextlib import suppress | ||
|
|
||
|
|
@@ -52,6 +53,11 @@ def zephyr_lora_files(): | |
| return snapshot_download(repo_id=LORA_NAME) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def zephyr_nonlora_files(): | ||
| return snapshot_download(repo_id=MODEL_NAME) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def monkeypatch_module(): | ||
| from _pytest.monkeypatch import MonkeyPatch | ||
|
|
@@ -62,7 +68,7 @@ def monkeypatch_module(): | |
|
|
||
| @pytest.fixture(scope="module", params=[False, True]) | ||
| def server_with_lora_modules_json(request, monkeypatch_module, | ||
| zephyr_lora_files): | ||
| zephyr_lora_files, adapter_cache): | ||
|
|
||
| use_v1 = request.param | ||
| monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0') | ||
|
|
@@ -98,12 +104,16 @@ def server_with_lora_modules_json(request, monkeypatch_module, | |
| "2", | ||
| "--max-num-seqs", | ||
| "64", | ||
| "--lora-cache-dir", | ||
| str(adapter_cache), | ||
| ] | ||
|
|
||
| # Enable the /v1/load_lora_adapter endpoint | ||
| envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} | ||
| lora_env = { | ||
| "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True", | ||
| } | ||
|
|
||
| with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: | ||
| with RemoteOpenAIServer(MODEL_NAME, args, | ||
| env_dict=lora_env) as remote_server: | ||
| yield remote_server | ||
|
|
||
|
|
||
|
|
@@ -131,6 +141,68 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, | |
| assert lora_models[1].id == "zephyr-lora2" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_cached_non_lora_adapter(client: openai.AsyncOpenAI, tmp_path, | ||
| zephyr_nonlora_files, adapter_cache): | ||
| """Validate that a cached model that isn't a lora adapter will not be | ||
| loaded from the cache directory""" | ||
| cached_nonlora_name = f"zephyr-7b-beta-{random.random()}" | ||
| model_files = adapter_cache / cached_nonlora_name | ||
| shutil.copytree(zephyr_nonlora_files, model_files) | ||
|
|
||
| models = await client.models.list() | ||
| models = models.data | ||
| lora_models = models[1:] | ||
| assert len(lora_models) == 2 | ||
|
|
||
| with pytest.raises(openai.NotFoundError): | ||
| await client.completions.create( | ||
| model=cached_nonlora_name, | ||
| prompt=["Hello there", "Foo bar bazz buzz"], | ||
| max_tokens=5, | ||
| ) | ||
|
|
||
| models = await client.models.list() | ||
| models = models.data | ||
| lora_models = models[1:] | ||
| assert len(lora_models) == 2 | ||
| assert lora_models[0].id != cached_nonlora_name | ||
| assert lora_models[1].id != cached_nonlora_name | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_cached_lora_adapter(client: openai.AsyncOpenAI, tmp_path, | ||
|
||
| zephyr_lora_files, adapter_cache): | ||
| """Validate that a lora adapter can be dynamically discovered and loaded | ||
| from the cache directory""" | ||
| cached_lora_name = f"zephyr-7b-beta-lora-{random.random()}" | ||
| model_files = adapter_cache / cached_lora_name | ||
| shutil.copytree(zephyr_lora_files, model_files) | ||
|
|
||
| models = await client.models.list() | ||
| models = models.data | ||
| lora_models = models[1:] | ||
| assert len(lora_models) == 2 | ||
| assert lora_models[0].id != cached_lora_name | ||
| assert lora_models[1].id != cached_lora_name | ||
|
|
||
| result = await client.completions.create( | ||
| model=cached_lora_name, | ||
| prompt=["Hello there", "Foo bar bazz buzz"], | ||
| max_tokens=5, | ||
| ) | ||
|
|
||
| assert not isinstance(result, Exception), f"Got exception {result}" | ||
| assert isinstance(result, openai.types.Completion) | ||
| assert result.model == cached_lora_name | ||
|
|
||
| models = await client.models.list() | ||
| models = models.data | ||
| lora_models = models[1:] | ||
| assert len(lora_models) == 3 | ||
| assert lora_models[2].id == cached_lora_name | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, | ||
| zephyr_lora_files): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,6 +42,7 @@ | |
| from vllm.entrypoints.launcher import serve_http | ||
| from vllm.entrypoints.logger import RequestLogger | ||
| from vllm.entrypoints.openai.cli_args import (make_arg_parser, | ||
| validate_lora_cache_args, | ||
| validate_parsed_serve_args) | ||
| # yapf conflicts with isort for this block | ||
| # yapf: disable | ||
|
|
@@ -951,6 +952,7 @@ async def init_app_state( | |
| args.response_role, | ||
| request_logger=request_logger, | ||
| chat_template=resolved_chat_template, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
|
||
| chat_template_content_format=args.chat_template_content_format, | ||
| return_tokens_as_token_ids=args.return_tokens_as_token_ids, | ||
| enable_auto_tools=args.enable_auto_tool_choice, | ||
|
|
@@ -963,6 +965,7 @@ async def init_app_state( | |
| engine_client, | ||
| model_config, | ||
| state.openai_serving_models, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| request_logger=request_logger, | ||
| return_tokens_as_token_ids=args.return_tokens_as_token_ids, | ||
| ) if model_config.runner_type == "generate" else None | ||
|
|
@@ -972,6 +975,7 @@ async def init_app_state( | |
| state.openai_serving_models, | ||
| request_logger=request_logger, | ||
| chat_template=resolved_chat_template, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| chat_template_content_format=args.chat_template_content_format, | ||
| ) if model_config.runner_type == "pooling" else None | ||
| state.openai_serving_embedding = OpenAIServingEmbedding( | ||
|
|
@@ -980,24 +984,28 @@ async def init_app_state( | |
| state.openai_serving_models, | ||
| request_logger=request_logger, | ||
| chat_template=resolved_chat_template, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| chat_template_content_format=args.chat_template_content_format, | ||
| ) if model_config.task == "embed" else None | ||
| state.openai_serving_scores = ServingScores( | ||
| engine_client, | ||
| model_config, | ||
| state.openai_serving_models, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| request_logger=request_logger) if model_config.task in ( | ||
| "score", "embed", "pooling") else None | ||
| state.jinaai_serving_reranking = ServingScores( | ||
| engine_client, | ||
| model_config, | ||
| state.openai_serving_models, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| request_logger=request_logger | ||
| ) if model_config.task == "score" else None | ||
| state.openai_serving_tokenization = OpenAIServingTokenization( | ||
| engine_client, | ||
| model_config, | ||
| state.openai_serving_models, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| request_logger=request_logger, | ||
| chat_template=resolved_chat_template, | ||
| chat_template_content_format=args.chat_template_content_format, | ||
|
|
@@ -1006,6 +1014,7 @@ async def init_app_state( | |
| engine_client, | ||
| model_config, | ||
| state.openai_serving_models, | ||
| lora_cache_dir=args.lora_cache_dir, | ||
| request_logger=request_logger, | ||
| ) if model_config.runner_type == "transcription" else None | ||
| state.task = model_config.task | ||
|
|
@@ -1067,7 +1076,12 @@ def signal_handler(*_) -> None: | |
| app = build_app(args) | ||
|
|
||
| model_config = await engine_client.get_model_config() | ||
| await init_app_state(engine_client, model_config, app.state, args) | ||
| await init_app_state( | ||
| engine_client, | ||
| model_config, | ||
| app.state, | ||
| args, | ||
| ) | ||
|
|
||
| def _listen_addr(a: str) -> str: | ||
| if is_valid_ipv6_address(a): | ||
|
|
@@ -1113,5 +1127,6 @@ def _listen_addr(a: str) -> str: | |
| parser = make_arg_parser(parser) | ||
| args = parser.parse_args() | ||
| validate_parsed_serve_args(args) | ||
| validate_lora_cache_args(args) | ||
|
|
||
| uvloop.run(run_server(args)) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this
==2? Are we also counting the adapter added in a previous test?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the server is shared between all the tests in the module.