From 8be09bb90f9ce20517c7431e2f318509cf79f729 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 6 Mar 2025 15:57:28 +0000 Subject: [PATCH 1/2] Done Signed-off-by: Jee Jee Li --- tests/lora/conftest.py | 24 ----------------- tests/lora/test_jamba.py | 54 --------------------------------------- tests/lora/test_layers.py | 3 +++ vllm/lora/layers.py | 37 +++++++++++++++++++++++---- 4 files changed, 35 insertions(+), 83 deletions(-) delete mode 100644 tests/lora/test_jamba.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dd14abff630c..f3b545670b88 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch import pytest -import safetensors import torch import torch.nn as nn from huggingface_hub import snapshot_download @@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules(): return snapshot_download(repo_id="dyang415/mixtral-lora-v0") -@pytest.fixture(scope="session") -def jamba_lora_files(): - # some of the adapters have unnecessary weights for serving, - # hence we remove them - def remove_unnecessary_weights(path): - lora_path = f"{adapter_path}/adapter_model.safetensors" - tensors = safetensors.torch.load_file(lora_path) - nonlora_keys = [] - for k in list(tensors.keys()): - if "lora" not in k: - nonlora_keys.append(k) - for k in nonlora_keys: - del tensors[k] - safetensors.torch.save_file(tensors, lora_path) - - adapter_path = snapshot_download( - repo_id= - "hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora") - - remove_unnecessary_weights(adapter_path) - return adapter_path - - @pytest.fixture(scope="session") def gemma_lora_files(): return snapshot_download(repo_id="wskwon/gemma-7b-test-lora") diff --git a/tests/lora/test_jamba.py b/tests/lora/test_jamba.py deleted file mode 100644 index 885851880b59..000000000000 --- a/tests/lora/test_jamba.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch - -import vllm -from vllm.lora.request import LoRARequest - -MODEL_PATH = "ai21labs/AI21-Jamba-1.5-Mini" - -MAX_TOKENS = 40 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, - prompts: list[str]) -> list[str]: - - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=MAX_TOKENS) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.parametrize("tp_size", [4]) -def test_jamba_lora(jamba_lora_files, tp_size): - """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count() < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - - prompts = ["Write a story about a sheep and a goat."] - - llm = vllm.LLM( - MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - distributed_executor_backend="ray", - tensor_parallel_size=tp_size, - ) - - expected_jamba_output = [ - """Once upon a time, in a lush green meadow, there lived a sheep named Clara and a goat named Billy. 
Clara was a gentle creature, always nibbling on the soft grass and humming""" # noqa: E501 - ] - assert do_sample(llm, jamba_lora_files, lora_id=1, - prompts=prompts) == expected_jamba_output diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 3507d0121212..3f46ddbcedff 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -632,6 +632,7 @@ def create_random_linear_replicated_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_replicated_layer() + assert torch.equal(linear.weight,lora_linear.weight) lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, @@ -757,6 +758,7 @@ def create_random_linear_parallel_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_parallel_layer() + assert torch.equal(linear.weight,lora_linear.weight) lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, @@ -904,6 +906,7 @@ class FakeConfig: id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_column_parallel_packed_layer() + assert torch.equal(linear.weight,lora_linear.weight) lora_linear.set_mapping(punica_wrapper) lora_dict, sublora_dict = populate_loras( id_to_index, diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index ff1b6501d1f1..1c1f76702ddb 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -274,6 +274,10 @@ def can_replace_layer( ) -> bool: return type(source_layer) is VocabParallelEmbedding + @property + def weight(self): + return self.base_layer.weight + class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): @@ -409,6 +413,34 @@ def apply(self, self.output_slices) return output + @property + def weight(self) -> torch.Tensor: + + # unquantizedLinear + if hasattr(self.base_layer, "weight"): + return self.base_layer.weight + # Compressed Tensor + elif hasattr(self.base_layer, "weight_packed"): + return self.base_layer.weight_packed + # GPTQ/AWQ + elif hasattr(self.base_layer, "qweight"): + return self.base_layer.qweight + # marlin + elif hasattr(self.base_layer, "B"): + return self.base_layer.B + # HQQ marlin + elif hasattr(self.base_layer, "W_q"): + return self.base_layer.W_q + else: + raise ValueError(f"Unsupported base layer: {self.base_layer}") + + @property + def bias(self) -> Optional[torch.Tensor]: + if hasattr(self.base_layer, "bias"): + return self.base_layer.bias + else: + return None + class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): @@ -902,11 +934,6 @@ def forward( return output, output_bias - @property - def weight(self): - return (self.base_layer.weight if hasattr(self.base_layer, "weight") - else self.base_layer.qweight) - @classmethod @_not_fully_sharded_can_replace def can_replace_layer( From bd9c438e8c44febfd671238b45ff61c6ff597556 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 6 Mar 2025 16:34:44 +0000 Subject: [PATCH 2/2] format Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 3f46ddbcedff..428a1c71d098 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -632,7 +632,7 @@ def create_random_linear_replicated_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_replicated_layer() - assert torch.equal(linear.weight,lora_linear.weight) + assert torch.equal(linear.weight, lora_linear.weight) 
lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, @@ -758,7 +758,7 @@ def create_random_linear_parallel_layer(): id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_random_linear_parallel_layer() - assert torch.equal(linear.weight,lora_linear.weight) + assert torch.equal(linear.weight, lora_linear.weight) lora_linear.set_mapping(punica_wrapper) lora_dict, _ = populate_loras( id_to_index, @@ -906,7 +906,7 @@ class FakeConfig: id_to_index = get_random_id_to_index(num_loras, max_loras) linear, lora_linear = create_column_parallel_packed_layer() - assert torch.equal(linear.weight,lora_linear.weight) + assert torch.equal(linear.weight, lora_linear.weight) lora_linear.set_mapping(punica_wrapper) lora_dict, sublora_dict = populate_loras( id_to_index,
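
A minimal sketch (not part of the patch, and not vLLM's actual implementation) of the attribute-resolution order that the new `BaseLinearLayerWithLoRA.weight` and `bias` properties introduce. The stand-in classes below are hypothetical; the patch's explicit if/elif chain is condensed here into a loop over the same attribute names, checked in the same order (unquantized `weight`, compressed-tensors `weight_packed`, GPTQ/AWQ `qweight`, Marlin `B`, HQQ Marlin `W_q`):

# Illustrative sketch only -- hypothetical stand-ins, not vLLM's real
# quantized linear layers. It mimics the attribute names the new
# BaseLinearLayerWithLoRA.weight property checks, in the same order.
from typing import Optional

import torch


class _LoRAWrapperSketch:
    """Minimal stand-in for the new weight/bias properties."""

    def __init__(self, base_layer) -> None:
        self.base_layer = base_layer

    @property
    def weight(self) -> torch.Tensor:
        # Same resolution order as the patch: unquantized linear first,
        # then compressed-tensors, GPTQ/AWQ, Marlin, and HQQ Marlin.
        for attr in ("weight", "weight_packed", "qweight", "B", "W_q"):
            if hasattr(self.base_layer, attr):
                return getattr(self.base_layer, attr)
        raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        # Mirrors the patch: return the base layer's bias if it has one.
        return getattr(self.base_layer, "bias", None)


class _FakeUnquantizedLinear:
    def __init__(self) -> None:
        self.weight = torch.randn(8, 8)
        self.bias = torch.zeros(8)


class _FakeGPTQLinear:
    def __init__(self) -> None:
        self.qweight = torch.randint(0, 255, (8, 8), dtype=torch.int32)


if __name__ == "__main__":
    unquantized = _FakeUnquantizedLinear()
    gptq = _FakeGPTQLinear()

    # Mirrors the new test assertion: the wrapper exposes the same tensor
    # that the wrapped base layer holds, whatever its attribute is named.
    assert torch.equal(_LoRAWrapperSketch(unquantized).weight,
                       unquantized.weight)
    assert torch.equal(_LoRAWrapperSketch(gptq).weight, gptq.qweight)
    assert _LoRAWrapperSketch(gptq).bias is None

This is also why the added test lines can assert `torch.equal(linear.weight, lora_linear.weight)` for replicated, row-parallel, and column-parallel-packed layers alike: the property uniformly resolves the underlying weight tensor regardless of the base layer's quantization scheme.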