From 5419cc83f953916c9e155d04bed08c26b6ce0e4a Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sun, 11 Aug 2024 00:04:24 +0800
Subject: [PATCH 1/6] done

---
 tests/lora/test_layer_variation.py  | 106 ----------------------------
 tests/lora/test_punica_sizes.py     |   2 +-
 tests/lora/test_punica_variation.py |  23 +-----
 3 files changed, 3 insertions(+), 128 deletions(-)
 delete mode 100644 tests/lora/test_layer_variation.py

diff --git a/tests/lora/test_layer_variation.py b/tests/lora/test_layer_variation.py
deleted file mode 100644
index ec9776b77df7..000000000000
--- a/tests/lora/test_layer_variation.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import tempfile
-from random import sample
-from typing import List, Optional
-
-import peft
-import pytest
-from transformers import AutoModelForCausalLM
-
-import vllm
-from vllm.lora.request import LoRARequest
-
-from .conftest import cleanup
-
-MODEL_PATH = "Felladrin/Llama-68M-Chat-v1"
-PROMPTS = [
-    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",  # noqa: E501
-    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",  # noqa: E501
-    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",  # noqa: E501
-]
-
-
-def get_lora_model(model_id: str, target_modules: List[str], rank: int):
-    model = AutoModelForCausalLM.from_pretrained(model_id)
-    lora_config = peft.tuners.lora.LoraConfig(target_modules, rank)
-    lora_model = peft.PeftModel(model, lora_config)
-    return lora_model
-
-
-def do_sample(llm: vllm.LLM,
-              lora_path: Optional[str] = None,
-              lora_id: Optional[int] = None,
-              logprobs: int = 0,
-              n_tokens: int = 256):
-    prompts = PROMPTS
-    sampling_params = vllm.SamplingParams(temperature=0,
-                                          max_tokens=n_tokens,
-                                          logprobs=logprobs,
-                                          stop=["[/assistant]"])
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None)
-    # Print the outputs.
-    generated_texts: List[str] = []
-    generated_logprobs: List[List[List[int]]] = []
-    for output in outputs:
-        prompt = output.prompt
-        generated_text = output.outputs[0].text
-        generated_texts.append(generated_text)
-        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
-        generated_logprobs.append([
-            list(logprob.keys()) for out in output.outputs
-            for logprob in out.logprobs
-        ])
-    return generated_logprobs if logprobs else generated_texts
-
-
-SUPPORTED_MODULES = [
-    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
-    "lm_head"
-]
-TARGET_MODULES_LIST = []
-for length in range(2, 6):
-    TARGET_MODULES_LIST.extend(
-        [sample(SUPPORTED_MODULES, length) for _ in range(3)])
-
-
-# Test the correctness when layer and rank are varied
-# step 1: init a base model and serve with LoRA to get the reference results
-# step 2: merge the same LoRA to the base model, serve the merged model
-# step 3: compare the results from step 1 and step 2
-@pytest.mark.parametrize("tp_size", [1])
-@pytest.mark.parametrize("target_modules", TARGET_MODULES_LIST)
-@pytest.mark.parametrize("rank", [8, 16, 32, 64])
-def test_layer_variation_correctness(tp_size, target_modules, rank):
-    llm = vllm.LLM(MODEL_PATH,
-                   enable_lora=True,
-                   max_num_seqs=16,
-                   max_loras=4,
-                   tensor_parallel_size=tp_size,
-                   worker_use_ray=True)
-    model = get_lora_model(MODEL_PATH, target_modules, rank)
-    with tempfile.TemporaryDirectory() as tmpdir:
-        model.save_pretrained(tmpdir)
-        merged_probs = do_sample(llm, tmpdir, 1, logprobs=5, n_tokens=32)
-    del llm
-    cleanup()
-    reference_id_sets = [set(prob[0]) for prob in merged_probs]
-
-    model = get_lora_model(MODEL_PATH, target_modules, rank)
-    with tempfile.TemporaryDirectory() as tmpdir:
-        merged_model = model.merge_and_unload()
-        merged_model.save_pretrained(tmpdir)
-        llm = vllm.LLM(tmpdir,
-                       tokenizer=MODEL_PATH,
-                       enable_lora=False,
-                       max_num_seqs=16,
-                       tensor_parallel_size=tp_size,
-                       worker_use_ray=True)
-    probs = do_sample(llm, logprobs=5, n_tokens=32)
-    del llm
-    cleanup()
-    # verify the top-5 tokens are identical for each token
-    id_sets = [set(prob[0]) for prob in probs]
-    assert id_sets == reference_id_sets
diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index c052568dc2e3..c36fb3afb0cc 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -98,7 +98,7 @@
     128256,
 ]
 #The size of TP
-divisibility = [1, 2, 4, 8, 16, 32, 64]
+divisibility = [1, 2, 8, 16, 64]
 
 all_hidden_size = []
 for div in divisibility:
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index 5bf3f72e7d97..d026e34878e0 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -20,10 +20,10 @@
 from .utils import (generate_data, generate_data_for_expand_nslices,
                     ref_torch_groupgemm)
 
-HIDDEN_SIZES = [3424, 4096, 4097]
+HIDDEN_SIZES = [4097]
 
 BATCHES = [1, 4, 16, 32]
-NUM_LORA = [1, 4, 8, 16, 32, 64, 128]
+NUM_LORA = [1, 8, 32, 128]
 DTYPES = [torch.float16, torch.bfloat16]
 MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
 SCALES = [0.5]
@@ -321,22 +321,3 @@ def test_punica_expand_nslices(
         slice_offset += hidden_size
 
     assert_close(our_outputs, ref_outputs)
-
-
-if __name__ == "__main__":
-    from itertools import product
-
-    lst = list(
-        product(
-            BATCHES,
-            NUM_LORA,
-            MAX_RANKS,
-            [1.0],
-            [torch.float16],
-            ["expand"],
-            SEED,
-            CUDA_DEVICES,
-        ))
-    for ele in lst:
-        test_punica_bgmv(*ele)
-        print(f"{ele},pass")

From be295695a576440607a5a1a6683475db98a5c3bb Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 16 Aug 2024 15:57:04 +0800
Subject: [PATCH 2/6] Enhance punica assert

---
 tests/lora/test_punica_sizes.py     |  5 +++++
 tests/lora/test_punica_variation.py |  5 +++++
 vllm/lora/ops/bgmv_expand_slice.py  |  2 +-
 vllm/lora/ops/bgmv_shrink.py        |  2 +-
 vllm/lora/ops/sgmv_expand.py        | 16 ++++++++------
 vllm/lora/ops/sgmv_expand_slice.py  | 18 ++++++++++------
 vllm/lora/ops/sgmv_shrink.py        | 16 ++++++++------
 vllm/lora/punica.py                 | 33 ++++++++++++++++-------------
 8 files changed, 61 insertions(+), 36 deletions(-)

diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index c36fb3afb0cc..71bca23e5b9f 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -172,6 +172,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -186,6 +187,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -198,6 +200,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -355,6 +358,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -374,6 +378,7 @@ def test_punica_expand_nslices(
                 max_seq_length,
                 slice_offset,
                 hidden_size,
+                token_nums,
                 add_inputs=True,
             )
         else:
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index d026e34878e0..f25f6c9a335e 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -87,6 +87,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -101,6 +102,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -113,6 +115,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -270,6 +273,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
    if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -289,6 +293,7 @@ def test_punica_expand_nslices(
                 max_seq_length,
                 slice_offset,
                 hidden_size,
+                token_nums,
                 add_inputs=True,
             )
         else:
diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py
index fa6571074f3a..3c154bb50c66 100644
--- a/vllm/lora/ops/bgmv_expand_slice.py
+++ b/vllm/lora/ops/bgmv_expand_slice.py
@@ -107,7 +107,7 @@ def bgmv_expand_slice(
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch, An index of -1 means no lora should
             be applied.
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.
diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py
index e69d33078f5a..e49a8fbf4d21 100644
--- a/vllm/lora/ops/bgmv_shrink.py
+++ b/vllm/lora/ops/bgmv_shrink.py
@@ -95,7 +95,7 @@ def bgmv_shrink(
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        scaling (float): Scaling factor.
+        scaling (float): Scaling factor
         override_config (Optional[Dict[str, int]], optional): Defaults to None.
             Triton grid config
     """
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py
index 459049546909..dcab99bfbab1 100644
--- a/vllm/lora/ops/sgmv_expand.py
+++ b/vllm/lora/ops/sgmv_expand.py
@@ -106,6 +106,7 @@ def sgmv_expand(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     add_inputs: bool = False,
 ):
     """
@@ -115,17 +116,19 @@ def sgmv_expand(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The total number of tokens in the batch. Used to
+            verify that the token count of the inputs matches the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
     assert inputs.dtype == lora_b_weights.dtype
     assert lora_b_weights.dtype in [
         torch.float16,
         torch.bfloat16,
     ]
@@ -134,6 +137,7 @@ def sgmv_expand(
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index ff3bcda071b8..16f9d1cc8829 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -114,6 +114,7 @@ def sgmv_expand_slice(
     max_seq_length: int,
     slice_offset: int,
     slice_size: int,
+    token_nums: int,
     add_inputs: bool = False,
 ):
     """_summary_
@@ -124,20 +125,22 @@ def sgmv_expand_slice(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
+        max_seq_length (int): The max sequence lengths of the sequences
             in the batch
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
-            results to the output..
+        token_nums (int): The total number of tokens in the batch. Used to
+            verify that the token count of the inputs matches the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
+            results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
         torch.bfloat16,
     ]
@@ -145,6 +148,7 @@ def sgmv_expand_slice(
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py
index 8ab049989abe..7c978b0a3539 100644
--- a/vllm/lora/ops/sgmv_shrink.py
+++ b/vllm/lora/ops/sgmv_shrink.py
@@ -110,6 +110,7 @@ def sgmv_shrink(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     scaling: float,
 ):
     """
@@ -120,17 +121,19 @@ def sgmv_shrink(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        scaling (float): Scaling factor.
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The total number of tokens in the batch. Used to
+            verify that the token count of the inputs matches the metadata.
+        scaling (float): Scaling factor.
     """
     assert inputs.dtype == lora_a_weights.dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
     assert lora_a_weights.dtype in [
         torch.float16,
         torch.bfloat16,
     ]
@@ -138,6 +141,7 @@ def sgmv_shrink(
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_a_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index 6d5c83429996..f9841e5aaffb 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -178,7 +178,7 @@ def convert_mapping(
 class PunicaWrapper:
     """
     PunicaWrapper is designed to manage and provide metadata for the punica
     kernel. The main function is to maintain the state information for
     Multi-LoRA, and to provide the interface for the punica kernel.
     """
@@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                                              dtype=torch.long,
                                              device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,7 +277,7 @@ def _update_base_metadata(
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
-
+        self.token_nums = base_indices.sum().item()
         self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
@@ -295,21 +296,23 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
     @property
     def prefill_metadata(
-        self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-            1. seq_start_locs: Tensor of sequence start positions
-            2. seq_lengths: Tensor of sequence lengths
+            1. seq_start_locs: Tensor of sequence start positions.
+            2. seq_lengths: Tensor of sequence lengths.
             3. lora_indices_per_batch: Tensor of lora indices, and an index
                 of -1 means no lora should be applied.
-            4. batch_size: batch size after clustering identical lora indices
-            5. max_length: The maximum sequence length in the batch
+            4. batch_size: Batch size after clustering identical lora indices.
+            5. max_length: The maximum sequence length in the batch.
+            6. token_nums: The total number of tokens in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)
 
     @property
     def token_lora_indices(self) -> torch.Tensor:
@@ -324,7 +327,7 @@ def token_lora_indices(self) -> torch.Tensor:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
@@ -332,7 +335,7 @@ def sampler_indices(self) -> torch.Tensor:
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +344,7 @@ def sampler_indices_padded(self) -> torch.Tensor:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
         """
         embeddings_indices_len = self.indices_len[3]
         return self._embeddings_indices[:, :embeddings_indices_len]
@@ -350,7 +353,7 @@ def embeddings_indices(self) -> torch.Tensor:
     def long_lora_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for long context
-        lora, specifically for LinearScalingRotaryEmbeddingWithLora
+        lora, specifically for LinearScalingRotaryEmbeddingWithLora.
         """
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
@@ -518,9 +521,9 @@ def add_lora(self,
                      ).squeeze(0)
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
-            x (torch.Tensor): Input tensor
-            wa_t_all (torch.Tensor): lora_a's weight
-            wb_t_all (torch.Tensor): lora_b's weight
+            x (torch.Tensor): Input tensor.
+            wa_t_all (torch.Tensor): lora_a's weight.
+            wb_t_all (torch.Tensor): lora_b's weight.
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the
                 starting column of y.

From eba3666f1f66d8b70b22f08f40cd8bf03c216a4b Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 16 Aug 2024 15:59:11 +0800
Subject: [PATCH 3/6] fix typo

---
 vllm/lora/punica.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index f9841e5aaffb..d6821db90bad 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -521,13 +521,13 @@ def add_lora(self,
                      ).squeeze(0)
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
-            x (torch.Tensor): Input tensor.
-            wa_t_all (torch.Tensor): lora_a's weight.
-            wb_t_all (torch.Tensor): lora_b's weight.
+            x (torch.Tensor): Input tensor
+            wa_t_all (torch.Tensor): lora_a's weight
+            wb_t_all (torch.Tensor): lora_b's weight
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the
                 starting column of y.
-            y_slice_size (Optional[int], optional): Size of the y column slice..
+            y_slice_size (Optional[int], optional): Size of the y column slice.
             buffer (Optional[torch.Tensor], optional): Defaults to None.
         """
         y_org = y

From b6bca7a22ff89f8ab032b57758286c9bb91670a3 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Fri, 16 Aug 2024 17:34:56 +0800
Subject: [PATCH 4/6] Fix token_nums bug

---
 vllm/lora/punica.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index d6821db90bad..15972ab6459f 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -27,7 +27,7 @@
 def compute_meta(
     token_lora_tensor: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
     """
     Get the information required for the sgmv kernel. With the features:
     1. If consecutive requests in the batch use the same LoRA, this function
     will combine them into a single request, improving sgmv kernel inference
     performance.
     2. At the beginning of each prefill stage inference, recalculations are
     needed based on the input, but only once.
     """
-
+
     lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
         token_lora_tensor, return_counts=True)
     cum_result = torch.cumsum(seq_length_tensor, dim=0)
     b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
     b_seq_start_tensor[1:].copy_(cum_result[:-1])
     max_length = seq_length_tensor.max().item()
-
+    token_nums = seq_length_tensor.sum().item()
     batch_size = lora_indices_tensor.size(0)
     no_lora = False
     # -1 means no lora should be applied. Use `no_lora` to determine whether
     if batch_size == 1 and lora_indices_tensor == -1:
         no_lora = True
     return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-            batch_size, max_length, no_lora)
+            batch_size, max_length, token_nums, no_lora)
 
 
 # TODO see if this can be vectorized
@@ -277,13 +277,13 @@ def _update_base_metadata(
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
-        self.token_nums = base_indices.sum().item()
         self.indices_len[:] = indices_len
 
     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
 
         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)
 
         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)
@@ -292,6 +292,7 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora
 
     @property

From 8a914bc65f23bf7db8a91fa9aa0c5841e2992bf0 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sat, 17 Aug 2024 00:04:46 +0800
Subject: [PATCH 5/6] Fix token_nums bug again

---
 tests/lora/test_punica_sizes.py    | 2 +-
 vllm/lora/ops/sgmv_expand_slice.py | 6 +++---
 vllm/lora/punica.py                | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index 71bca23e5b9f..2abba06dfb9e 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -376,9 +376,9 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
-                token_nums,
                 add_inputs=True,
             )
         else:
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index 16f9d1cc8829..4bb22b51ec37 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -112,9 +112,9 @@ def sgmv_expand_slice(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     slice_offset: int,
     slice_size: int,
-    token_nums: int,
     add_inputs: bool = False,
 ):
     """_summary_
@@ -135,10 +135,10 @@ def sgmv_expand_slice(
         batches (int): batch size
         max_seq_length (int): The max sequence lengths of the sequences
             in the batch
-        slice_offset (int): output_tensor's offset
-        slice_size (int): current output_tensor's size
         token_nums (int): The total number of tokens in the batch. Used to
             verify that the token count of the inputs matches the metadata.
+        slice_offset (int): output_tensor's offset
+        slice_size (int): current output_tensor's size
         add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index 15972ab6459f..5033ce412692 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -36,7 +36,7 @@ def compute_meta(
     2. At the beginning of each prefill stage inference, recalculations are
     needed based on the input, but only once.
""" - + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( token_lora_tensor, return_counts=True) cum_result = torch.cumsum(seq_length_tensor, dim=0) From eb497df703cccb3ad4e86c45e088d20fd8f6b310 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 17 Aug 2024 00:50:58 +0800 Subject: [PATCH 6/6] Address unit test error --- tests/lora/test_punica_variation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index f25f6c9a335e..1e2238dc1acd 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -291,9 +291,9 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, - token_nums, add_inputs=True, ) else: