[CORE] Adding support for insertion of soft-tuned prompts #4645
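For context: "soft-tuned prompts" (prompt tuning, as produced by PEFT) are learned virtual-token embeddings that get prepended to a request's input embeddings before the model runs. A minimal conceptual sketch — illustrative only, with plain lists standing in for tensors; the function name and representation are assumptions, not the vLLM implementation:

```python
# Illustrative sketch of soft-prompt insertion (not the vLLM implementation).
# Each "embedding" is a plain list of floats standing in for a tensor row.

def insert_soft_prompt(input_embeds, soft_prompt_embeds):
    """Prepend learned virtual-token embeddings to the prompt's embeddings."""
    return soft_prompt_embeds + input_embeds

# A hypothetical 8-virtual-token adapter applied to a 3-token prompt:
soft_prompt = [[0.1, 0.2]] * 8  # 8 learned virtual tokens (dim 2)
prompt = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
combined = insert_soft_prompt(prompt, soft_prompt)
print(len(combined))  # 11
```

The `8` passed as the last positional argument of `PromptAdapterRequest` throughout the tests below corresponds to this virtual-token count.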
Merged

Changes from all commits (91 commits)
04f262e
soft prompt support
96b4a1a
Run yapf and ruff
3131273
Multimodal fix
e9ff38b
correctness update
9f0a8ae
formatting
c2937d1
formatting
e43e89b
reverting to hasattr
a2b4fc3
adapter commons fix
3ebee19
minor fixes
629a684
formatting
a3ad6ac
reset_adapter
dcd7e88
bugfix
647a32d
reset_adapter fix
90d170c
peft dependencies
0fca895
fixing llava bug
d4e531c
typing fix
b7f8256
async engine update
449d988
batchwise processing
f28b66e
formatting
220deef
formatting yapf
01b9bb8
formatting again
2ea2796
enable_adapter parameter
96fe5ae
formatting
47725d9
adding test
638795a
adding test
f7d53b3
test case update
16f4037
formatting
f2f3cbc
resetting
0fc0c34
formatting
4eb47d6
formatting
e69842b
formatting
5c17480
Fix async engine
e62cbb5
Initial implementation of openai entrypoint
20fc56f
Merge branch 'main' into main
SwapnilDreams100 612d6c5
Fixes
894b9ba
async changes
00efe02
Merge branch 'main' into main
SwapnilDreams100 155ad76
formatting
042c9f1
formatting
0e46a06
adding dtype flexibility + pa lora refactor
3d14475
formatting
86e72de
formatting
41934cc
xpu compatibility
fdfec59
xpu compatibility
6b1f0e7
xpu compatibility
01bb713
xpu compatibility
3e5e147
Merge branch 'main' into main
SwapnilDreams100 d7312e2
formatting
454d45b
formatting + updating tests
409dba1
test changes
ab95ad7
cpu-gpu sync changes + adapter abstract changes
2faec61
formatting
f1a607c
Merge branch 'main' into main
SwapnilDreams100 6955301
rebase
2814aee
peft fix
0e45660
minor fix
d58e355
formatting
d700324
forward update
a5610a7
formatting
6b1c5ef
Merge branch 'main' into main
SwapnilDreams100 8b6e827
formatting
b83b6f0
spec decode fix
4babf0f
Merge branch 'main' into main
SwapnilDreams100 791ffbd
formatting
7226246
Merge branch 'main' into main
SwapnilDreams100 215947d
async executor
9ae47e8
formatting
3a2b545
formatting
bbaea88
formatting
34dbc8f
Merge branch 'main' into openai-entrypoint
9c2cc27
Merge branch 'main' into main
SwapnilDreams100 cdcea67
formatting
e771d43
max_prompt_adapter_token defaults + error messages
503adf4
updating tests
45c12ee
fix eager issue
9a73128
Merge branch 'main' into main
SwapnilDreams100 13d42c6
formatting
b2f3842
formatting
191f2c9
replacing numel with ndim for LoRA consistency
50514c3
Update tests/prompt_adapter/test_bloom.py
SwapnilDreams100 1217964
Update vllm/prompt_adapter/models.py
SwapnilDreams100 f9a5b4a
formatting
8545205
formatting
2d5c246
formatting
3da2777
docs update
9634b9d
Merge pull request #2 from g-eoj/openai-entrypoint
SwapnilDreams100 8279496
formatting
4336df1
formatting
77183d7
quick openapi fix
dd887f8
formatting
67a9f17
formatting
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| import pytest | ||
|
|
||
| import vllm | ||
| from vllm.prompt_adapter.request import PromptAdapterRequest | ||
|
|
||
| MODEL_PATH = "bigscience/bloomz-560m" | ||
| PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' | ||
|
|
||
|
|
||
| def do_sample(llm, pa_name: str, pa_id: int): | ||
|
|
||
| prompts = [ | ||
| "Tweet text : @nationalgridus I have no water and the bill is \ | ||
| current and paid. Can you do something about this? Label : ", | ||
| "Tweet text : @nationalgridus Looks good thanks! Label : " | ||
| ] | ||
| sampling_params = vllm.SamplingParams(temperature=0.0, | ||
| max_tokens=3, | ||
| stop_token_ids=[3]) | ||
|
|
||
| outputs = llm.generate(prompts, | ||
| sampling_params, | ||
| prompt_adapter_request=PromptAdapterRequest( | ||
| pa_name, pa_id, PA_PATH, 8) if pa_id else None) | ||
|
|
||
| # Print the outputs. | ||
| generated_texts = [] | ||
| for output in outputs: | ||
| prompt = output.prompt | ||
| generated_text = output.outputs[0].text.strip() | ||
| generated_texts.append(generated_text) | ||
| print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
| return generated_texts | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("enforce_eager", [True, False]) | ||
| def test_twitter_prompt_adapter(enforce_eager: bool): | ||
| llm = vllm.LLM(MODEL_PATH, | ||
| enforce_eager=enforce_eager, | ||
| enable_prompt_adapter=True, | ||
| max_prompt_adapter_token=8) | ||
|
|
||
| expected_output = ['complaint', 'no complaint'] | ||
|
|
||
| assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output |
New test (file name not shown; +53 lines) — multiple prompt adapters on bloomz-560m:

```python
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'


def do_sample(engine):

    prompts = [
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3), None),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("complain", 3, pa_path, 8)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_multi_prompt_adapters():
    engine_args = EngineArgs(model=MODEL_PATH,
                             max_prompt_adapters=3,
                             enable_prompt_adapter=True,
                             max_prompt_adapter_token=8)
    engine = LLMEngine.from_engine_args(engine_args)
    expected_output = {
        ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
    }
    assert do_sample(engine) == expected_output
```
New test (file name not shown; +61 lines) — prompt adapter combined with a LoRA adapter on Llama-2-7b:

```python
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


def do_sample(engine):

    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # first prompt with a prompt adapter and second without adapter
    prompts = [
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         PromptAdapterRequest("hate_speech", 1, pa_path,
                              8), LoRARequest("sql_test", 1, lora_path)),
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]), None,
         LoRARequest("sql_test", 1, lora_path)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request,
                               lora_request=lora_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_lora_prompt_adapter():
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_prompt_adapter=True,
                             enable_lora=True,
                             max_num_seqs=60,
                             max_prompt_adapter_token=8)
    engine = LLMEngine.from_engine_args(engine_args)
    result = do_sample(engine)

    expected_output = {
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    }
    assert result == expected_output
```
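All of these tests drive `LLMEngine` with the same submit-then-step polling pattern: feed one prompt per iteration, call `step()`, and collect outputs whose `finished` flag is set. The control flow can be sketched with a toy engine (`EchoEngine` is invented for illustration; it only mimics the method names `add_request`, `step`, and `has_unfinished_requests`, not vLLM's real behavior):

```python
class EchoEngine:
    """Toy stand-in for LLMEngine: finishes each request on the next step."""

    def __init__(self):
        self.pending = []

    def add_request(self, request_id, prompt):
        self.pending.append((request_id, prompt))

    def has_unfinished_requests(self):
        return bool(self.pending)

    def step(self):
        # Finish every pending request and return its "output".
        finished, self.pending = self.pending, []
        return [(rid, prompt.upper()) for rid, prompt in finished]


prompts = ["a", "b", "c"]
engine = EchoEngine()
request_id = 0
results = {}
# Same loop shape as the tests: keep feeding while stepping the engine.
while prompts or engine.has_unfinished_requests():
    if prompts:
        engine.add_request(str(request_id), prompts.pop(0))
        request_id += 1
    for rid, text in engine.step():
        results[rid] = text
print(results)  # {'0': 'A', '1': 'B', '2': 'C'}
```

The loop condition matters: it keeps stepping even after the prompt list is drained, so late-finishing requests are still collected.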
Empty file.
New file (+14 lines) — AdapterMapping dataclass shared by adapter types:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
```
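`AdapterMapping` can be exercised standalone. A minimal sketch — the dataclass is re-declared locally so the snippet runs without vLLM, and the batch shown (two requests sharing one forward pass) is a hypothetical example:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:  # mirrors the dataclass in the diff above
    index_mapping: Tuple[int, ...]   # one adapter index per input token
    prompt_mapping: Tuple[int, ...]  # one adapter index per sampled token

    def __post_init__(self):
        # Lists are accepted at construction but normalized to immutable tuples.
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


# Two batched requests: tokens 0-2 belong to adapter 1, tokens 3-4 to adapter 2.
mapping = AdapterMapping(index_mapping=[1, 1, 1, 2, 2],
                         prompt_mapping=[1, 2])
print(type(mapping.index_mapping).__name__)  # tuple
```

Normalizing to tuples in `__post_init__` makes the mapping hashable-friendly and safe to share across components without defensive copies.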
New file (+104 lines) — abstract adapter base classes (AdapterModel, AdapterLRUCache, AdapterModelManager):

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from torch import nn

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel(ABC):

    def __init__(self, model_id=None):
        self.id = model_id

    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        # Common initialization code
        # Load weights or embeddings from local checkpoint
        raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
                                                              None]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: Hashable, value: T):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class AdapterModelManager(ABC):

    def __init__(
        self,
        model: nn.Module,
    ):
        """Create a AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}
        self.adapter_type = 'Adapter'
        self._last_mapping = None

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self):
        ...

    @property
    @abstractmethod
    def capacity(self):
        ...

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        ...

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        ...

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_adapters(self):
        ...

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        ...

    @abstractmethod
    def list_adapters(self) -> Dict[int, Any]:
        ...

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        ...
```
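The key behavior of `AdapterLRUCache` is its eviction hook: when the cache overflows, the least-recently-used adapter is deactivated via `deactivate_fn` before being dropped. A self-contained sketch of that pattern — `MiniLRUCache` is a hypothetical stand-in for `vllm.utils.LRUCache`, written here only so the snippet runs without vLLM:

```python
from collections import OrderedDict
from typing import Callable, Hashable


class MiniLRUCache:
    """Hypothetical stand-in for vllm.utils.LRUCache (illustrative only)."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache: "OrderedDict[Hashable, object]" = OrderedDict()

    def put(self, key: Hashable, value: object) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        while len(self.cache) > self.capacity:
            old_key, old_value = self.cache.popitem(last=False)  # evict LRU
            self._on_remove(old_key, old_value)

    def _on_remove(self, key: Hashable, value: object) -> None:
        pass  # hook for subclasses


class AdapterLRUCache(MiniLRUCache):
    """On eviction, deactivate the adapter before dropping it (as in the diff)."""

    def __init__(self, capacity: int,
                 deactivate_fn: Callable[[Hashable], None]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key, value):
        self.deactivate_fn(key)
        super()._on_remove(key, value)


deactivated = []
cache = AdapterLRUCache(capacity=2, deactivate_fn=deactivated.append)
for adapter_id in (1, 2, 3):  # inserting a third adapter evicts the LRU one
    cache.put(adapter_id, f"adapter-{adapter_id}")
print(deactivated)  # [1]
```

This is why `_active_adapters` in `AdapterModelManager` is a `Dict[int, None]` rather than a set: the LRU cache machinery operates on key-value mappings.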
New file (+25 lines) — AdapterRequest base class:

```python
from abc import abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
    """
    Base class for adapter requests.
    """

    @property
    @abstractmethod
    def adapter_id(self):
        ...

    def __post_init__(self):
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, self.__class__) and self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)
```
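`AdapterRequest` compares and hashes by `adapter_id` alone, so concrete dataclass subclasses need `eq=False` (or an equivalent) to avoid `@dataclass` regenerating `__eq__` and clobbering that identity. A runnable sketch — the base class mirrors the diff above, while `DemoAdapterRequest` is a hypothetical subclass invented for illustration:

```python
from abc import abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
    """Base class for adapter requests (mirrors the diff above)."""

    @property
    @abstractmethod
    def adapter_id(self):
        ...

    def __post_init__(self):
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, self.__class__) and self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)


@dataclass(eq=False)  # eq=False keeps the inherited __eq__/__hash__
class DemoAdapterRequest(AdapterRequest):
    """Hypothetical concrete subclass, for illustration only."""
    demo_adapter_id: int

    @property
    def adapter_id(self) -> int:
        return self.demo_adapter_id


a = DemoAdapterRequest(7)
b = DemoAdapterRequest(7)
print(a == b, hash(a) == hash(b))  # True True
```

The `__post_init__` check also runs for subclasses, so any request constructed with an id below 1 raises `ValueError` at creation time.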