From d993de90ce34313b8d70ab8e368acc49a0d88d01 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 20 Nov 2024 19:20:21 +0000 Subject: [PATCH 001/186] Added non-triton SGMV and BGMV ops (not kernels yet) Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 + vllm/lora/ops/__init__.py | 0 vllm/lora/ops/xla/lora_ops.py | 119 ++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) delete mode 100644 vllm/lora/ops/__init__.py create mode 100644 vllm/lora/ops/xla/lora_ops.py diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5a4d991da1b5..7e0173936ac6 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1107,6 +1107,7 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits + print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py new file mode 100644 index 000000000000..8ad32dd4a77b --- /dev/null +++ b/vllm/lora/ops/xla/lora_ops.py @@ -0,0 +1,119 @@ +import torch + +def sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + add_inputs + ) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:] += outputs[:] + else: + output_tensor[:] = outputs[:] + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_shrink( + inputs, + lora_a_weights, + output_tensor, + exploded_indices, + scaling + ) + +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0 +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:] = scaling * outputs[:] + +def sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) + + bgmv_expand_slice( + inputs, + lora_b_weights, + output_tensor, + exploded_indices, + slice_offset, + slice_size, 
+ add_inputs + ) + + +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True +): + selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + inputs = inputs.to(dtype=torch.float16) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file From 4f816ed58e2a714a46beb8c0087511948ec98186 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 20 Nov 2024 19:21:03 +0000 Subject: [PATCH 002/186] Made a copy of the layer tests for the TPU. TODO: DRY it out Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 2 +- tests/lora/test_layers_tpu.py | 1220 +++++++++++++++++++++++++++++++++ 2 files changed, 1221 insertions(+), 1 deletion(-) create mode 100644 tests/lora/test_layers_tpu.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dd14abff630c..ee0807386391 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -71,7 +71,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, diff --git a/tests/lora/test_layers_tpu.py b/tests/lora/test_layers_tpu.py new file mode 100644 index 000000000000..29f732c621af --- /dev/null +++ b/tests/lora/test_layers_tpu.py @@ -0,0 +1,1220 @@ +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple +from unittest.mock import patch + +import pytest +import torch +import torch.nn.functional as F + +from vllm.config import LoRAConfig +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, LoRAMapping, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights) +from vllm.lora.punica import PunicaWrapper +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} +TPU_DEVICES = [ + f"xla:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# We will launch different triton 
kernels between the prefill and decode +# stages, so we need to verify this. prefill stage(True) or decode stage(False) +STAGES = [True, False] + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots, device="cpu")[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras: List[LoRALayerWeights] = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "xla" +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. 
+ """ + + low, high = input_range + + inputs: List[torch.Tensor] = [] + index_mapping: List[int] = [] + prompt_mapping: List[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[vocab_size:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + lora_embedding.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip( +# reason="Fails when loras are in any 
slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) +@pytest.mark.parametrize("stage", STAGES) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device, + vocab_size, stage) -> None: + + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + embedding.weight.data[vocab_size:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + vocab_size + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=vocab_size) + expanded_embedding.weight.data[:vocab_size, :] = embedding_data + # We need to deepcopy the embedding as it will be modified + # in place + lora_embedding = VocabParallelEmbeddingWithLoRA( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, vocab_size + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + lora_embedding.set_mapping(punica_wrapper) + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. 
+ for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) + original_input_[-1] = vocab_size + input_[-2] = vocab_size + ( + (embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = vocab_size + embeddings_tensor_len - 1 + + expanded_embedding.weight[vocab_size:vocab_size + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results: List[torch.Tensor] = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + original_inputs = deepcopy(inputs) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) +@pytest.mark.parametrize("stage", STAGES) +def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, + stage) -> None: + + torch.set_default_device(device) + max_loras = 8 + punica_wrapper = PunicaWrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def _pretest(): + linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, + 1024, + vocab_size, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, vocab_size:] = 0 + logits_processor = LogitsProcessor( + vocab_size + lora_config.lora_extra_vocab_size, vocab_size) + lora_logits_processor = LogitsProcessorWithLoRA( + logits_processor, 1024, linear.weight.dtype, linear.weight.device, + None) + lora_logits_processor.create_lora_weights(max_loras, lora_config) + + return linear, logits_processor, lora_logits_processor + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, logits_processor, lora_logits_processor = _pretest() + lora_logits_processor.set_mapping(punica_wrapper) + # NOTE: all the generated loras share the same embeddings tensor. 
+ lora_dict, _ = populate_loras( + id_to_index, + layer=lora_logits_processor, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) + input_ = torch.rand(20, 1024, dtype=torch.float16) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=linear, + embedding_bias=None) + + original_lm_head = deepcopy(linear) + + linear.weight[logits_processor. + org_vocab_size:logits_processor.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + logits_processor.org_vocab_size = (vocab_size + + lora_config.lora_extra_vocab_size) + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = logits_processor._get_logits(hidden_states=input_, + lm_head=linear, + embedding_bias=None) + result[:, vocab_size + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + logits_processor.org_vocab_size = vocab_size + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_logits_processor.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + vocab_size, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=original_lm_head, + embedding_bias=None)[:, :vocab_size] + expected_result = logits_processor._get_logits( + hidden_states=torch.cat(inputs), + lm_head=original_lm_head, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_replicated(dist_init, num_loras, device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_replicated_layer(): + + linear = ReplicatedLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ReplicatedLinearWithLoRA(linear) + + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = 
create_random_linear_replicated_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("fully_shard", [True, False]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, + device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16) + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard + else RowParallelLinearWithShardedLoRA(linear)) + else: + linear = ColumnParallelLinear(4096, + 4096, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (ColumnParallelLinearWithLoRA(linear) + if not fully_shard else + ColumnParallelLinearWithShardedLoRA(linear)) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, 
index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("fully_shard", [True, False]) +@pytest.mark.parametrize("device", TPU_DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device, stage) -> None: + + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16) + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard else + MergedColumnParallelLinearWithShardedLoRA(linear)) + elif repeats == 3: + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = (MergedQKVParallelLinearWithLora(linear) + if not fully_shard else + MergedQKVParallelLinearWithShardedLora(linear)) + else: + linear = QKVParallelLinear(4096, + 64, + 32, + bias=False, + params_dtype=torch.float16) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = QKVParallelLinearWithLora( + linear + ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = 
get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + lora_linear.set_mapping(punica_wrapper) + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results: List[torch.Tensor] = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * + (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * + sublora.scaling) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float16, + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=stage) + + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + # lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 8]) +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), + (6.0, 1.0)]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [None, 32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_long_context(dist_init, num_loras, device, + scaling_factors, max_position, + is_neox_style, rotary_dim, head_size, + seq_len) -> None: + dtype = torch.float16 + seed = 0 + current_platform.seed_everything(seed) + torch.set_default_device(device) + punica_wrapper = PunicaWrapper(8192, 256, device) + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + long_lora_scaling_factors=scaling_factors, + lora_dtype=dtype) + + if rotary_dim is None: + rotary_dim = head_size + base = 10000 + batch_size = 5 * num_loras + num_heads = 7 + + # Verify lora is equivalent to linear scaling rotary embedding. 
+ rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + ) + lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) + lora_rope.set_mapping(punica_wrapper) + lora_rope.create_lora_weights(max_loras, lora_config) + linear_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "rope_type": "linear", + "factor": scaling_factors + }) + linear_rope = linear_rope.to(dtype=dtype) + id_to_index = get_random_id_to_index(num_loras, max_loras) + _, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=batch_size, + input_size=(1, max_position), + input_range=(0, lora_config.lora_extra_vocab_size), + input_type=torch.float16, + device=device) + + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + long_lora_context = LongContextLoRAContext(list(scaling_factors), + rotary_dim) + + next_expected_offset = 0 + # Make sure the offset is correct. + scaling_factor_to_offset = lora_rope.scaling_factor_to_offset + for scaling_factor, offset in scaling_factor_to_offset.items(): + assert offset == next_expected_offset + next_expected_offset += scaling_factor * max_position + + for i in range(len(scaling_factors)): + long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( + scaling_factors[i], 0) + punica_wrapper.update_metadata( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + long_lora_context=long_lora_context, + ) + # lora_rope.set_mapping(*mapping_info) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + ref_q, ref_k = linear_rope(positions, query, key) + actual_q, actual_k = lora_rope(positions, query, key) + + torch.allclose(ref_q, actual_q) + torch.allclose(ref_k, actual_k) + + +@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seed", list(range(256))) +def test_vocab_parallel_embedding_indices(tp_size, seed): + random.seed(seed) + vocab_size = random.randint(4000, 64000) + added_vocab_size = random.randint(0, 1024) + org_vocab_size = vocab_size - added_vocab_size + last_org_vocab_end_index = 0 + last_added_vocab_end_index = org_vocab_size + computed_vocab_size = 0 + computed_org_vocab_size = 0 + computed_added_vocab_size = 0 + vocab_size_padded = -1 + + all_org_tokens: List[int] = [] + all_added_tokens: List[int] = [] + token_ids: List[int] = [] + + for tp_rank in range(tp_size): + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", + return_value=tp_rank + ), patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", + return_value=tp_size): + vocab_embedding = VocabParallelEmbedding( + vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size_padded = vocab_embedding.num_embeddings_padded + shard_indices = vocab_embedding.shard_indices + # Assert that the ranges are contiguous + assert shard_indices.org_vocab_start_index == last_org_vocab_end_index + assert (shard_indices.added_vocab_start_index == + last_added_vocab_end_index) + + # Ensure that we are not exceeding the vocab size + computed_vocab_size += shard_indices.num_elements_padded + computed_org_vocab_size += shard_indices.num_org_elements + computed_added_vocab_size += shard_indices.num_added_elements + + # Ensure that the ranges are not overlapping + all_org_tokens.extend( + range(shard_indices.org_vocab_start_index, + 
shard_indices.org_vocab_end_index)) + all_added_tokens.extend( + range(shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index)) + + token_ids.extend( + range(shard_indices.org_vocab_start_index, + shard_indices.org_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_org_elements_padded - + shard_indices.num_org_elements)) + token_ids.extend( + range(shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_added_elements_padded - + shard_indices.num_added_elements)) + + last_org_vocab_end_index = shard_indices.org_vocab_end_index + last_added_vocab_end_index = shard_indices.added_vocab_end_index + + assert computed_vocab_size == vocab_size_padded + assert computed_org_vocab_size == org_vocab_size + assert computed_added_vocab_size == added_vocab_size + + # Ensure that the ranges are not overlapping + assert len(all_org_tokens) == len(set(all_org_tokens)) + assert len(all_added_tokens) == len(set(all_added_tokens)) + assert not set(all_org_tokens).intersection(set(all_added_tokens)) + + token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) + reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() + assert reindex_mapping is not None or tp_size == 1 + if reindex_mapping is not None: + reindexed_token_ids = token_ids_tensor[reindex_mapping] + expected = torch.tensor(list(range(0, vocab_size))) + assert reindexed_token_ids[:vocab_size].equal(expected) + assert torch.all(reindexed_token_ids[vocab_size:] == -1) + + +def test_get_masked_input_and_mask(): + x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + + # base tp 1 case, no padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(x, modified_x) + + # tp 2 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + + # tp 4 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=0) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 
0, 0, 0, 0, 2, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) + + # base tp 1 case, with padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x, + torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) + + # tp 2 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + + # tp 4 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=2) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) From 5f0355bbb377a566764fb560e006ee25e7472696 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 21 Nov 2024 11:40:53 +0000 Subject: [PATCH 003/186] Removed extra print Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7e0173936ac6..5a4d991da1b5 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1107,7 +1107,6 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - print("punica", logits.dtype) # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, self.lora_a_stacked, From edd02c58f950d342e10b6b4e68e6b1d646fb795c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 22 Nov 2024 12:27:02 +0000 Subject: [PATCH 004/186] Made some minor shape-based fixes to the kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py index 8ad32dd4a77b..cd12f3659f47 100644 --- a/vllm/lora/ops/xla/lora_ops.py +++ b/vllm/lora/ops/xla/lora_ops.py @@ -30,14 +30,18 @@ def bgmv_expand( lora_indices_tensor: torch.Tensor, add_inputs: bool = True ): - selected_loras 
= lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + if add_inputs: - output_tensor[:] += outputs[:] + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] else: - output_tensor[:] = outputs[:] + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] def sgmv_shrink( inputs: torch.Tensor, @@ -68,10 +72,10 @@ def bgmv_shrink( lora_indices_tensor: torch.Tensor, scaling: float = 1.0 ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - output_tensor[:] = scaling * outputs[:] + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] def sgmv_expand_slice( inputs: torch.Tensor, @@ -109,8 +113,8 @@ def bgmv_expand_slice( slice_size: int, add_inputs: bool = True ): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze() - inputs = inputs.to(dtype=torch.float16) + selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: From aff94f966320949ea9129b49488ae299ae15e633 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 22 Nov 2024 15:11:44 +0000 Subject: [PATCH 005/186] Added basic lora execution code Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 86 ++++++++++++++++++++++++++++++--- vllm/worker/tpu_worker.py | 20 ++++++-- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 53541a2579ed..321a9d21a4ec 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import enum import time from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Set, Type, Union) from unittest.mock import patch @@ -17,8 +17,12 @@ from vllm.config import VllmConfig from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) @@ -62,6 +66,8 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int n: List[int] seq_groups: List[List[int]] + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None is_first_multi_step: bool = True is_last_step: bool = True virtual_engine: int = 0 @@ -72,6 +78,8 @@ def as_broadcastable_tensor_dict( tensor_dict = { "token_ids": self.token_ids, "position_ids": self.position_ids, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, "input_lens": self.input_lens, "t": self.t, 
"p": self.p, @@ -122,6 +130,9 @@ def __init__( False, ) self.cached_step_outputs: List[torch.Tensor] = [] + + # LoRA support + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None smem_size = 512 * 1024 block_table_size = 4 * self.block_tables.size @@ -154,16 +165,37 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) + self.model = model + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + max_pos_embeddings = self.model.config.max_position_embeddings + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.model_config.get_vocab_size(), + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + self.model = ModelWrapper(self.model) self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + backend="openxla", + fullgraph=True, + dynamic=False) def get_model(self) -> nn.Module: return self.model.model - def _dummy_run( + def _dummy_run( # KRAI-TODO: Add lora config here self, batch_size: int, seq_len: int, @@ -600,6 +632,15 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None + + print(f"\e[0;31m SELF LORA CONFIG {self.lora_config} \033[0m") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -765,7 +806,38 @@ def execute_model( sampler_output = _make_decode_output(next_token_ids, model_input.seq_groups) return [sampler_output] - + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + print("\e[0;31mSetting active loras\033[0m") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() class ModelWrapper(nn.Module): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 1a5eaba09b94..eb04479a10c3 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Set import torch import torch_xla.core.xla_model as xm @@ 
-13,18 +13,19 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoRANotSupportedWorkerBase, WorkerBase, + WorkerBase, WorkerInput) logger = init_logger(__name__) -class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class TPUWorker(LocalOrDistributedWorkerBase): def __init__( self, @@ -85,6 +86,7 @@ def init_device(self) -> None: # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and # 30-40 graphs for decode. 128 is an arbitrary safe number. torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.reorderable_logging_functions = set([print]) # Use persistent cache to avoid XLA recompilation. # NOTE(woosuk): Set per-rank cache path since different ranks # can have slightly different XLA graphs. @@ -287,6 +289,18 @@ def execute_worker(self, worker_input: WorkerInput) -> None: if src_indices.numel() > 0: attn_backend.copy_blocks(self.tpu_cache, (src_indices, dst_indices)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() def _make_src_to_dst( From adfd1941961b906992d1b69554fd7b030bc8be6f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:00:04 +0000 Subject: [PATCH 006/186] Replaced einsums with matmuls+reshaping for better xla compilation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py index cd12f3659f47..51167ddf1b6b 100644 --- a/vllm/lora/ops/xla/lora_ops.py +++ b/vllm/lora/ops/xla/lora_ops.py @@ -32,8 +32,11 @@ def bgmv_expand( ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 @@ -73,7 +76,10 @@ def bgmv_shrink( scaling: float = 1.0 ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -115,7 +121,11 @@ def bgmv_expand_slice( ): selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, 
boi -> bo", inputs, selected_loras) + + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + if add_inputs: output_tensor[:, slice_offset:slice_offset+slice_size] += outputs[:] From 816a56c25b5d25644fb5f5b0abf23d2bdd10cb9a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:02:26 +0000 Subject: [PATCH 007/186] Replaced inf/-inf with max/min since XLA doesn't allow `nan_to_num_()` to be called with infinities Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5a4d991da1b5..e1b25ec69a5b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1092,12 +1092,18 @@ def _get_logits( lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + + # KRAI: Temporary change + neg_inf = torch.finfo(lora_logits.dtype).min + pos_inf = torch.finfo(lora_logits.dtype).max + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) + print(f"AKSHAT - After index select: {lora_logits.shape}, {indices_padded.shape}") # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): From c8a51c81708cf82828e2531d3685c170bc084d11 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:03:37 +0000 Subject: [PATCH 008/186] Added lora config to `_dummy_run()` Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 321a9d21a4ec..6918d8c32830 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -44,7 +44,7 @@ # FIXME(woosuk): A temporary hack to support `n > 1`. # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 - +LORA_WARMUP_RANK = 8 # KRAI: TODO: Should this not be max rank - so we have better startup times? class ExecutionMode(enum.Enum): PREFILL = enum.auto() @@ -54,7 +54,6 @@ class ExecutionMode(enum.Enum): def is_prefill(self) -> bool: return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -282,6 +281,27 @@ def _dummy_run( # KRAI-TODO: Add lora config here t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 + + # Create a series of dummy loras and requests for them. Make to fill all lora slots. 
+ if self.lora_config: + dummy_lora_requests: Set[LoRARequest] = set() + dummy_lora_mapping: LoRAMapping + + assert self.lora_manager is not None + with self.lora_manager.dummy_lora_cache(): + for lora_id in range(1, self.lora_config.max_loras + 1): + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.add(dummy_lora_request) + dummy_lora_mapping = LoRAMapping( + [lora_id] * seq_len, [lora_id], is_prefill=exec_mode.is_prefill() + ) + self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile From 51f929d30fc5c8ba5cbb5974eb8864d614666131 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:04:04 +0000 Subject: [PATCH 009/186] Changed torch._dynamo config Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 6918d8c32830..9a4599ffaaeb 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -310,10 +310,9 @@ def _dummy_run( # KRAI-TODO: Add lora config here # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). + torch._dynamo.config.capture_dynamic_output_shape_ops = True if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) + # Prefill torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode From 23d4a2417293994316c02f15337666e971faebdf Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 25 Nov 2024 17:55:33 +0000 Subject: [PATCH 010/186] Quick patch to allow non lora code to run Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 9a4599ffaaeb..7239b6b43c59 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -310,9 +310,13 @@ def _dummy_run( # KRAI-TODO: Add lora config here # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). 
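        # The LoRA path introduces ops whose output shapes depend on data (for
        # example the tensor-valued repeat_interleave used by the sgmv ops), so
        # capture_dynamic_output_shape_ops is only switched on when LoRA is
        # configured; the plain mark_dynamic path is kept for the non-LoRA case.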
- torch._dynamo.config.capture_dynamic_output_shape_ops = True if exec_mode.is_prefill(): # Prefill + if self.lora_config is not None: + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode From 47397a738ec4f7f195c9ad9f6b76fa9cf1a7ecc7 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 17 Jan 2025 15:23:58 +0000 Subject: [PATCH 011/186] Minor fixes Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 346 +++++++++++++++++++++++++ vllm/platforms/tpu.py | 9 + 2 files changed, 355 insertions(+) create mode 100644 vllm/lora/punica_wrapper/punica_tpu.py diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 000000000000..ffac5b2c362e --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,346 @@ +from typing import Callable, Optional, Tuple, Union + +import torch + +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase + + +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_inputs, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_inputs, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: int, + y_slice_size: int, + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + 
computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. 
+ add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. 
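+        # Two bgmv calls implement the semantics documented above: the shrink
+        # projects the hidden states down to the LoRA rank into `buffer`
+        # (already scaled), and the expand multiplies by lora_b and accumulates
+        # into the logits, with sampler_indices selecting the adapter used for
+        # each sampled token.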
+ bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0b66b52713e9..0ef558850c0d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -57,6 +57,15 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 + @classmethod + def is_pin_memory_available(cls): + logger.warning("Pin memory is not supported on TPU.") + return False + + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + @classmethod def inference_mode(cls): return torch.no_grad() From 456eb3768ca79fa0657735d23e481d8f444a90da Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 22 Jan 2025 11:11:26 +0000 Subject: [PATCH 012/186] Replaced einsums with matmuls to allow xla compilation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/torch_ops/lora_ops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index af79f98415cb..30240c5e0bc9 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -30,7 +30,9 @@ def bgmv_expand(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -71,7 +73,9 @@ def bgmv_shrink(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -107,7 +111,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) - outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] From eabc748e45742b05996f085cc5ae94563ff379a1 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:01:16 +0000 Subject: [PATCH 013/186] Removed xla ops for torch ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla/lora_ops.py | 133 ---------------------------------- 1 file changed, 133 deletions(-) delete mode 100644 vllm/lora/ops/xla/lora_ops.py diff --git a/vllm/lora/ops/xla/lora_ops.py b/vllm/lora/ops/xla/lora_ops.py deleted file mode 100644 index 51167ddf1b6b..000000000000 --- 
a/vllm/lora/ops/xla/lora_ops.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch - -def sgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - add_inputs - ) - - -def bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if add_inputs: - output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] - else: - output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] - -def sgmv_shrink( - inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - scaling: float, -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_shrink( - inputs, - lora_a_weights, - output_tensor, - exploded_indices, - scaling - ) - -def bgmv_shrink( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0 -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - -def sgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor) - - bgmv_expand_slice( - inputs, - lora_b_weights, - output_tensor, - exploded_indices, - slice_offset, - slice_size, - add_inputs - ) - - -def bgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True -): - selected_loras = lora_b_weights[lora_indices_tensor].squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - - if add_inputs: - output_tensor[:, 
slice_offset:slice_offset+slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset+slice_size] = outputs[:] \ No newline at end of file From ac9753e2711f297e59859e4ced0b4a2e0a9bb268 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:02:11 +0000 Subject: [PATCH 014/186] Removed old debug log points Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 1 - vllm/worker/tpu_model_runner.py | 3 --- 2 files changed, 4 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e1b25ec69a5b..a8e24a8a38d1 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1103,7 +1103,6 @@ def _get_logits( ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)) - print(f"AKSHAT - After index select: {lora_logits.shape}, {indices_padded.shape}") # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 7239b6b43c59..3170f61a2183 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -656,8 +656,6 @@ def execute_model( ) -> List[SamplerOutput]: assert intermediate_tensors is None - print(f"\e[0;31m SELF LORA CONFIG {self.lora_config} \033[0m") - if self.lora_config: assert model_input.lora_requests is not None assert model_input.lora_mapping is not None @@ -839,7 +837,6 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - print("\e[0;31mSetting active loras\033[0m") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: From aa8b0fd904710b95cb78a7f2ab81cda6fc059330 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 12:02:35 +0000 Subject: [PATCH 015/186] Fixed bgmv/sgmv shape error Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index ffac5b2c362e..cd8349889ffd 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -270,7 +270,7 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. - x (torch.Tensor): Input tensor + x (torch.Tensor): Input tensor (B, S, E) lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. 
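The hunk below swaps the per-slice tuple of scratch buffers for a single stacked tensor. As a minimal, self-contained sketch of the two layouts (the sizes here are invented purely for illustration; the real code derives them from x and output_slices):

    import torch

    num_slices, num_tokens, rank = 3, 16, 8

    # Old layout: one 2-D scratch tensor per output slice.
    buffer_tuple = tuple(
        torch.zeros((num_tokens, rank), dtype=torch.float32)
        for _ in range(num_slices))

    # New layout: a single stacked 3-D tensor. buffer[slice_idx] still yields a
    # (num_tokens, rank) view for add_shrink()/add_expand() to write into, and
    # the row count is meant to match the flattened token count that add_shrink
    # operates on after x.view(-1, hidden_size).
    buffer = torch.zeros((num_slices, num_tokens, rank), dtype=torch.float32)

    assert buffer[0].shape == buffer_tuple[0].shape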
@@ -289,10 +289,11 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + buffer = torch.zeros( + (len(output_slices), x.size(1), r), + dtype=torch.float32, + device=x.device, + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand(y, buffer, From 124215fa30dfff6987d0293698dc0f6dbf54543a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 14:41:22 +0000 Subject: [PATCH 016/186] Fixed lora batching crash in warmup Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 3170f61a2183..cbb81017c252 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -298,9 +298,9 @@ def _dummy_run( # KRAI-TODO: Add lora config here self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK) dummy_lora_requests.add(dummy_lora_request) - dummy_lora_mapping = LoRAMapping( - [lora_id] * seq_len, [lora_id], is_prefill=exec_mode.is_prefill() - ) + dummy_lora_mapping = LoRAMapping( + [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() + ) self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and @@ -384,7 +384,7 @@ def warmup_model( # Decode start = time.time() seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() + batch_size = _get_padded_batch_size(1) while True: self._dummy_run(batch_size, seq_len, From e14825405fa838531e77a7d654ca43ed02605f80 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 15:02:13 +0000 Subject: [PATCH 017/186] Fixed shape issue in add_lora_linear() Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index cd8349889ffd..b40da63a3517 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -289,8 +289,9 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op + batch_size, seq_len, _ = x.shape buffer = torch.zeros( - (len(output_slices), x.size(1), r), + (len(output_slices), batch_size * seq_len, r), dtype=torch.float32, device=x.device, ) From 494b35ebb5c8263aeb1e668b1f9eff26dc6a8e1f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 15:02:59 +0000 Subject: [PATCH 018/186] Fixed dynamic lora tensor shapes Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index cbb81017c252..4b00400216f6 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -320,13 +320,17 @@ def _dummy_run( # KRAI-TODO: Add lora config here torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) + if self.lora_config is not None: + 
torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + pass + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. with set_forward_context(attn_metadata, self.vllm_config, 0): From 1dbfcd9f0ec76556e926332974732de925787458 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 23 Jan 2025 16:15:23 +0000 Subject: [PATCH 019/186] Fixed lora_input preparation for actual execution Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 101 ++++++++++++++++++++++++++++---- 1 file changed, 89 insertions(+), 12 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 4b00400216f6..b4ebec9e261a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -65,8 +65,7 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int n: List[int] seq_groups: List[List[int]] - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None + lora_inputs: List[Tuple[Set[LoRARequest], LoRAMapping]] is_first_multi_step: bool = True is_last_step: bool = True virtual_engine: int = 0 @@ -77,8 +76,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "token_ids": self.token_ids, "position_ids": self.position_ids, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, + "lora_inputs": self.lora_inputs, "input_lens": self.input_lens, "t": self.t, "p": self.p, @@ -641,8 +639,81 @@ def prepare_model_input( list(metadata.seq_data.keys()) for metadata in seq_group_metadata_list ] - return ModelInputForTPU(input_tokens, input_positions, attn_metadata, - input_lens, t, p, num_samples, n, seq_groups) + + lora_inputs = [] + if self.load_config is not None: + lora_inputs = self._prepare_lora_input(seq_group_metadata_list, is_prompt, padded_batch_size) + + return ModelInputForTPU( + token_ids=input_tokens, + position_ids=input_positions, + attn_metadata=attn_metadata, + input_lens=input_lens, + t=t, + p=p, + num_samples=num_samples, + n=n, + seq_groups=seq_groups, + lora_inputs=lora_inputs + ) + + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool, + padded_batch_size: int) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: + """ + Prepares a list of LoRA inputs. 
If we're decoding then the list will only have 1 item, + otherwise there'll be an item for each sequence + """ + + lora_input = [] + if is_prefill: + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + padded_query_len = _get_padded_prefill_len(query_len) + + index_mapping = [lora_id] * padded_query_len + prompt_mapping = [lora_id] + + lora_request = set() + if seq.lora_request is not None: + lora_request.add(seq.lora_request) + + lora_input.append(( + lora_request, + LoRAMapping( + index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=True + ) + )) + else: + lora_request = set() + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + + index_mapping += [lora_id] + prompt_mapping += [lora_id] + + if seq.lora_request is not None: + lora_request.add(seq.lora_request) + + index_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) + prompt_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) + + lora_input.append(( + lora_request, + LoRAMapping( + index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=False + ) + )) + + return lora_input def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: @@ -660,12 +731,6 @@ def execute_model( ) -> List[SamplerOutput]: assert intermediate_tensors is None - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -741,6 +806,12 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) + + if self.lora_config is not None: + assert len(model_input.lora_inputs) == batch_size + lora_requests, lora_mapping = model_input.lora_inputs[i] + self.set_active_loras(lora_requests, lora_mapping) + with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): @@ -790,6 +861,12 @@ def execute_model( t = model_input.t.to(self.device) p = model_input.p.to(self.device) input_lens = model_input.input_lens.to(self.device) + + if self.lora_config is not None: + assert len(model_input.lora_inputs) == 1 + lora_requests, lora_mapping = model_input.lora_inputs[0] + self.set_active_loras(lora_requests, lora_mapping) + for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping with set_forward_context(model_input.attn_metadata, From 1bb2578b96abd4de507bd8776c00c37c14ae68a6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 24 Jan 2025 16:42:22 +0000 Subject: [PATCH 020/186] Fixed wrong model bug Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index b4ebec9e261a..8cfec1fb9a42 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -184,7 +184,7 @@ def load_model(self) -> None: self.model = self.lora_manager.create_lora_manager(self.model) self.model = ModelWrapper(self.model) - self.model = torch.compile(model, + self.model = torch.compile(self.model, backend="openxla", fullgraph=True, dynamic=False) @@ -321,7 +321,6 @@ def _dummy_run( # KRAI-TODO: Add lora config here if 
self.lora_config is not None: torch._dynamo.config.capture_dynamic_output_shape_ops = True else: - pass torch._dynamo.mark_dynamic(token_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(input_lens, 0) From ddc4cbc5117be3cfbc148979093c7a74026d6063 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 24 Jan 2025 16:51:49 +0000 Subject: [PATCH 021/186] Moved if statements outside of for loops in PunicaWrapperTPU Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 54 +++++++------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b40da63a3517..b0a8149d5b7d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -111,43 +111,6 @@ def _expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) - def _apply_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool = True, - ): - """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs): @@ -170,10 +133,19 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], """ x = x.view(-1, x.shape[-1]) + + shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) + # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + y_s = y[slice_idx] + lora_s = lora_a_stacked[slice_idx] + + y_org = y_s + y_s = y_s.view(-1, y_s.shape[-1]) + + shrink_fun(y_s, x, lora_s, scale) + y_s = y_s.view_as(y_org) def add_expand(self, y: torch.Tensor, @@ -203,6 +175,8 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. 
""" + expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode) + y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start @@ -210,7 +184,7 @@ def add_expand(self, self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( + expand_slice_fun( y, x[slice_idx], lora_b_stacked[slice_idx], From 48a69442e52638fd461ec37fd451651ff6ef6a4a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 28 Jan 2025 14:50:58 +0000 Subject: [PATCH 022/186] Added early exits to PunicaWrapperTPU lora functions Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b0a8149d5b7d..64bd8fa16917 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -48,6 +48,8 @@ def _shrink_decode( w_t_all: torch.Tensor, scale: float, ): + if self.no_lora: + return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( @@ -75,6 +77,8 @@ def _expand_decode( w_t_all: torch.Tensor, add_inputs: bool, ): + if self.no_lora: + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( @@ -108,6 +112,8 @@ def _expand_slice_decode( y_slice_size: int, add_inputs: bool, ): + if self.no_lora: + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) From 7802e842e425fc9cec8d6c3e873132a1be9d0deb Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 30 Jan 2025 12:33:39 +0000 Subject: [PATCH 023/186] Added torch ops for tpu (Static prefill sizes) Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 13 +++ vllm/lora/ops/xla_ops/lora_ops.py | 118 +++++++++++++++++++++++++ vllm/lora/punica_wrapper/punica_tpu.py | 2 +- 3 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 vllm/lora/ops/xla_ops/__init__.py create mode 100644 vllm/lora/ops/xla_ops/lora_ops.py diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 000000000000..4785af8520d3 --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,13 @@ +from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 000000000000..5dc0c98bbb48 --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,118 @@ +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = 
lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + inputs.size(0)) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 64bd8fa16917..b6739bd97bdb 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -2,7 +2,7 @@ import torch -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, +from vllm.lora.ops.xla_ops import (bgmv_expand, 
bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) From ab5396ba28ebe970234e587d905d8720cc08f78c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 30 Jan 2025 17:34:42 +0000 Subject: [PATCH 024/186] XLA bgmv operations are now imported from the default torch_ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 65 +------------------------------ 1 file changed, 1 insertion(+), 64 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 5dc0c98bbb48..d6c630880644 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,5 +1,5 @@ import torch - +from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -17,31 +17,6 @@ def sgmv_expand(inputs: torch.Tensor, bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) - -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if add_inputs: - output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] - else: - output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] - - def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -61,23 +36,6 @@ def sgmv_shrink( scaling) -def bgmv_shrink(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] - - def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -95,24 +53,3 @@ def sgmv_expand_slice(inputs: torch.Tensor, bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, slice_offset, slice_size, add_inputs) - - -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - inputs = inputs.to(dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) - - if 
add_inputs: - output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] - else: - output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] From fdf29d33d8d7632b8bebf162ae8144ec7228719b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 16:15:22 +0000 Subject: [PATCH 025/186] Removed TODOs Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 10 +++++++--- vllm/worker/tpu_model_runner.py | 5 ++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8e24a8a38d1..16c5771a5123 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1093,9 +1093,13 @@ def _get_logits( lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - # KRAI: Temporary change - neg_inf = torch.finfo(lora_logits.dtype).min - pos_inf = torch.finfo(lora_logits.dtype).max + if current_platform.is_tpu(): + # Because nan_to_num_ doesn't work with actual -inf values on TPU + neg_inf = torch.finfo(lora_logits.dtype).min + pos_inf = torch.finfo(lora_logits.dtype).max + else: + neg_inf = float("-inf") + pos_inf = float("inf") lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 8cfec1fb9a42..248bcc4627c1 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -44,7 +44,6 @@ # FIXME(woosuk): A temporary hack to support `n > 1`. # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 -LORA_WARMUP_RANK = 8 # KRAI: TODO: Should this not be max rank - so we have better startup times? class ExecutionMode(enum.Enum): PREFILL = enum.auto() @@ -192,7 +191,7 @@ def load_model(self) -> None: def get_model(self) -> nn.Module: return self.model.model - def _dummy_run( # KRAI-TODO: Add lora config here + def _dummy_run( self, batch_size: int, seq_len: int, @@ -294,7 +293,7 @@ def _dummy_run( # KRAI-TODO: Add lora config here lora_path="/not/a/real/path", ) self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) + rank=self.lora_config.max_lora_rank) dummy_lora_requests.add(dummy_lora_request) dummy_lora_mapping = LoRAMapping( [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() From c2b4139b29b77347c875a1479ccb4deea2c00f78 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 17:44:21 +0000 Subject: [PATCH 026/186] Removed old code Signed-off-by: Akshat Tripathi --- tests/lora/test_layers_tpu.py | 1220 --------------------------------- 1 file changed, 1220 deletions(-) delete mode 100644 tests/lora/test_layers_tpu.py diff --git a/tests/lora/test_layers_tpu.py b/tests/lora/test_layers_tpu.py deleted file mode 100644 index 29f732c621af..000000000000 --- a/tests/lora/test_layers_tpu.py +++ /dev/null @@ -1,1220 +0,0 @@ -import random -from copy import deepcopy -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple -from unittest.mock import patch - -import pytest -import torch -import torch.nn.functional as F - -from vllm.config import LoRAConfig -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, - MergedColumnParallelLinearWithShardedLoRA, - MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, - RowParallelLinearWithShardedLoRA) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - 
LinearScalingRotaryEmbeddingWithLora, - LogitsProcessorWithLoRA, LoRAMapping, - MergedColumnParallelLinearWithLoRA, - MergedQKVParallelLinearWithLora, - QKVParallelLinearWithLora, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) -# yapf: enable -from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, - PackedLoRALayerWeights) -from vllm.lora.punica import PunicaWrapper -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) -from vllm.model_executor.utils import set_random_seed -from vllm.platforms import current_platform - -from .utils import DummyLoRAManager - -TOLERANCES = { - torch.float16: (5e-3, 5e-3), - torch.float32: (5e-3, 5e-3), - torch.bfloat16: (3e-2, 2e-2), -} -TPU_DEVICES = [ - f"xla:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -# We will launch different triton kernels between the prefill and decode -# stages, so we need to verify this. prefill stage(True) or decode stage(False) -STAGES = [True, False] - - -def get_random_id_to_index(num_loras: int, - num_slots: int, - log: bool = True) -> List[Optional[int]]: - """Creates a random lora_id_to_index mapping. - - Args: - num_loras: The number of active loras in the mapping. - num_slots: The number of slots in the mapping. Must be larger - than num_loras. - log: Whether to log the output. - """ - - if num_loras > num_slots: - raise ValueError( - f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " - "num_loras must be less than or equal to num_slots.") - - slots: List[Optional[int]] = [None] * num_slots - random_slot_selections = (torch.randperm(num_slots, device="cpu")[:num_loras]).tolist() - for lora_id, slot_idx in enumerate(random_slot_selections, start=1): - slots[slot_idx] = lora_id - - if log: - print(f"Created lora_id_to_index mapping: {slots}.") - - return slots - - -def populate_loras( - id_to_index: List[Optional[int]], - layer: BaseLayerWithLoRA, - layer_weights: torch.Tensor, - generate_embeddings_tensor: int = 0, - repeats: int = 1, -) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: - """This method populates the lora layers with lora weights. - - Args: - id_to_index: a list of lora ids. The index of the lora id - represents which memory slot the lora matrices are - stored in. A None value indicates a free slot. - layer: the LoRAlayer to populate. - layer_weights: the PyTorch tensor containing the layer's - weights. - generate_embeddings_tensor: whether to generate an - embeddings tensor for each LoRA. - repeats: must only be set for column parallel packed - layers. Indicates the number of loras to compose - together to create a single lora layer. - """ - - # Dictionary that maps the lora ID to the - # corresponding lora weights. - lora_dict: Dict[int, LoRALayerWeights] = dict() - - # Dictionary that maps the lora ID to the - # corresponding subloras. 
- sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() - - for slot_idx, lora_id in enumerate(id_to_index): - if lora_id is not None: - subloras: List[LoRALayerWeights] = [] - sublora_len = layer_weights.shape[0] // repeats - for i in range(repeats): - sublora = DummyLoRAManager( - layer_weights.device).init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) - sublora.lora_b = sublora.lora_b[:, (sublora_len * - i):(sublora_len * (i + 1))] - sublora.optimize() - subloras.append(sublora) - - lora = PackedLoRALayerWeights.pack( - subloras) if repeats > 1 else subloras[0] - - layer.set_lora( - slot_idx, - lora_a=lora.lora_a, - lora_b=lora.lora_b, - embeddings_tensor=lora.embeddings_tensor, - ) - - lora_dict[lora_id] = lora - sublora_dict[lora_id] = subloras - - return lora_dict, sublora_dict - - -def create_random_inputs( - active_lora_ids: List[int], - num_inputs: int, - input_size: Tuple[int, ...], - input_range: Tuple[float, float], - input_type: torch.dtype = torch.int, - device: torch.device = "xla" -) -> Tuple[List[torch.Tensor], List[int], List[int]]: - """Creates random inputs. - - Args: - active_lora_ids: lora IDs of active lora weights. - num_inputs: the number of inputs to create. - input_size: the size of each individual input. - input_range: the range of values to include in the input. - input_range[0] <= possible input values < input_range[1] - input_type: the type of values in the input. - """ - - low, high = input_range - - inputs: List[torch.Tensor] = [] - index_mapping: List[int] = [] - prompt_mapping: List[int] = [] - - for _ in range(num_inputs): - if input_type == torch.int: - inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device=device)) - else: - inputs.append( - torch.rand(size=input_size, dtype=input_type, device=device) * - high + low) - - lora_id = random.choice(active_lora_ids) - index_mapping += [lora_id] * input_size[0] - prompt_mapping += [lora_id] - - return inputs, index_mapping, prompt_mapping - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding.weight.data = torch.rand_like(embedding.weight.data) - embedding.weight.data[vocab_size:, :] = 0 - lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return embedding, lora_embedding - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - embedding, lora_embedding = create_random_embedding_layer() - lora_embedding.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=embedding.weight.T, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - 
is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - - lora_result = lora_embedding(torch.cat(inputs)) - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = embedding(input_) - after_a = F.embedding( - input_, - lora.lora_a, - ) - result += (after_a @ lora.lora_b) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - - lora_result = lora_embedding(torch.cat(inputs)) - expected_result = embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -# @pytest.mark.skip( -# reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) -@pytest.mark.parametrize("stage", STAGES) -def test_embeddings_with_new_embeddings(dist_init, num_loras, device, - vocab_size, stage) -> None: - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def create_random_embedding_layer(): - embedding = VocabParallelEmbedding(vocab_size, 256) - embedding_data = torch.rand_like(embedding.weight.data) - embedding.weight.data = embedding_data - embedding.weight.data[vocab_size:, :] = 0 - expanded_embedding = VocabParallelEmbedding( - vocab_size + lora_config.lora_extra_vocab_size * max_loras, - 256, - org_num_embeddings=vocab_size) - expanded_embedding.weight.data[:vocab_size, :] = embedding_data - # We need to deepcopy the embedding as it will be modified - # in place - lora_embedding = VocabParallelEmbeddingWithLoRA( - deepcopy(expanded_embedding)) - lora_embedding.create_lora_weights(max_loras, lora_config) - - return expanded_embedding, lora_embedding - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - expanded_embedding, lora_embedding = create_random_embedding_layer() - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_embedding, - layer_weights=torch.zeros( - (256, vocab_size + lora_config.lora_extra_vocab_size)), - generate_embeddings_tensor=256, - ) - - lora_embedding.set_mapping(punica_wrapper) - # All embeddings tensors have the same shape. - embeddings_tensors = [ - lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) - ] - embeddings_tensor_len = embeddings_tensors[0].shape[0] - - # Add empty embeddings_tensors for unoccupied lora slots. 
- for _ in range(max_loras - len(embeddings_tensors)): - embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - original_inputs = deepcopy(inputs) - - # Force some of the inputs to be in the extended embeddings range - # to guarantee that their behavior is tested. - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): - embedding_id = lora_id - 1 - input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len) - original_input_[-1] = vocab_size - input_[-2] = vocab_size + ( - (embedding_id + 1) * embeddings_tensor_len - 1) - original_input_[-2] = vocab_size + embeddings_tensor_len - 1 - - expanded_embedding.weight[vocab_size:vocab_size + - (embeddings_tensor_len * - max_loras)] = torch.cat(embeddings_tensors) - - lora_result = lora_embedding(torch.cat(original_inputs)) - - expected_results: List[torch.Tensor] = [] - for input_, original_input_, lora_id in zip(inputs, original_inputs, - prompt_mapping): - lora = lora_dict[lora_id] - result = expanded_embedding(input_) - after_a = F.embedding( - original_input_, - lora.lora_a, - ) - result += (after_a @ lora.lora_b) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_embedding.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=num_loras * 3, - input_size=(200, ), - input_range=(1, vocab_size), - device=device) - original_inputs = deepcopy(inputs) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - vocab_size, - lora_config.lora_extra_vocab_size) - lora_result = lora_embedding(torch.cat(original_inputs)) - expected_result = expanded_embedding(torch.cat(inputs)) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) -@pytest.mark.parametrize("stage", STAGES) -def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, - stage) -> None: - - torch.set_default_device(device) - max_loras = 8 - punica_wrapper = PunicaWrapper(8192, 256, device) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def _pretest(): - linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size, - 1024, - vocab_size, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - linear.weight.data[:, vocab_size:] = 0 - logits_processor = LogitsProcessor( - vocab_size + lora_config.lora_extra_vocab_size, vocab_size) - lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, 
linear.weight.device, - None) - lora_logits_processor.create_lora_weights(max_loras, lora_config) - - return linear, logits_processor, lora_logits_processor - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, logits_processor, lora_logits_processor = _pretest() - lora_logits_processor.set_mapping(punica_wrapper) - # NOTE: all the generated loras share the same embeddings tensor. - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_logits_processor, - layer_weights=linear.weight, - generate_embeddings_tensor=1024, - ) - embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor - embeddings_tensor_len = embeddings_tensor.shape[0] - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=8 * num_loras, # * 3, - input_size=(1, 1024), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - input_ = torch.rand(20, 1024, dtype=torch.float16) - - lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=linear, - embedding_bias=None) - - original_lm_head = deepcopy(linear) - - linear.weight[logits_processor. - org_vocab_size:logits_processor.org_vocab_size + - embeddings_tensor_len] = embeddings_tensor - - logits_processor.org_vocab_size = (vocab_size + - lora_config.lora_extra_vocab_size) - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = logits_processor._get_logits(hidden_states=input_, - lm_head=linear, - embedding_bias=None) - result[:, vocab_size + embeddings_tensor_len:] = float("-inf") - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - logits_processor.org_vocab_size = vocab_size - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_logits_processor.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=8 * num_loras * 3, - input_size=(1, 1024), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - vocab_size, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=original_lm_head, - embedding_bias=None)[:, :vocab_size] - expected_result = logits_processor._get_logits( - hidden_states=torch.cat(inputs), - lm_head=original_lm_head, - embedding_bias=None) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_linear_replicated(dist_init, num_loras, device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16) - - def 
create_random_linear_replicated_layer(): - - linear = ReplicatedLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = ReplicatedLinearWithLoRA(linear) - - lora_linear.create_lora_weights(max_loras, lora_config) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, lora_linear = create_random_linear_replicated_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("orientation", ["row", "column"]) -@pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) - - def create_random_linear_parallel_layer(): - if orientation == "row": - linear = RowParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard - else RowParallelLinearWithShardedLoRA(linear)) - else: - linear = ColumnParallelLinear(4096, - 4096, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (ColumnParallelLinearWithLoRA(linear) - if not fully_shard else - 
ColumnParallelLinearWithShardedLoRA(linear)) - lora_linear.create_lora_weights(max_loras, lora_config) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - linear, lora_linear = create_random_linear_parallel_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, _ = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - lora = lora_dict[lora_id] - result = linear(input_)[0] - result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - # Check that resetting the lora weights succeeds - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, - 512, lora_config.lora_extra_vocab_size) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -@pytest.mark.parametrize("repeats", [1, 2, 3]) -@pytest.mark.parametrize("fully_shard", [True, False]) -@pytest.mark.parametrize("device", TPU_DEVICES) -@pytest.mark.parametrize("stage", STAGES) -def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: - - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) - - def create_column_parallel_packed_layer(): - if repeats == 2: - linear = MergedColumnParallelLinear(4096, [4096] * repeats, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedColumnParallelLinearWithLoRA(linear) - if not fully_shard else - MergedColumnParallelLinearWithShardedLoRA(linear)) - elif repeats == 3: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) - linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = (MergedQKVParallelLinearWithLora(linear) - if not fully_shard else - MergedQKVParallelLinearWithShardedLora(linear)) - else: - linear = QKVParallelLinear(4096, - 64, - 32, - bias=False, - params_dtype=torch.float16) - linear.weight.data = 
torch.rand_like(linear.weight.data) - lora_linear = QKVParallelLinearWithLora( - linear - ) if not fully_shard else QKVParallelLinearWithShardedLora(linear) - - @dataclass - class FakeConfig: - hidden_size = 4096 - num_key_value_heads = 32 - num_attention_heads = 32 - - lora_linear.create_lora_weights(max_loras, - lora_config, - model_config=FakeConfig()) - - return linear, lora_linear - - for i in range(10): - set_random_seed(i) - - id_to_index = get_random_id_to_index(num_loras, max_loras) - - linear, lora_linear = create_column_parallel_packed_layer() - lora_linear.set_mapping(punica_wrapper) - lora_dict, sublora_dict = populate_loras( - id_to_index, - layer=lora_linear, - layer_weights=linear.weight, - repeats=repeats, - ) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=list(lora_dict.keys()), - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - - lora_result = lora_linear(torch.cat(inputs))[0] - - expected_results: List[torch.Tensor] = [] - for input_, lora_id in zip(inputs, prompt_mapping): - result = linear(input_)[0] - subloras = sublora_dict[lora_id] - for i, sublora in enumerate(subloras): - result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * - (i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b * - sublora.scaling) - expected_results.append(result) - expected_result = torch.cat(expected_results) - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - for slot_idx in range(max_loras): - lora_linear.reset_lora(slot_idx) - - inputs, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=32 * num_loras, - input_size=(1, 4096), - input_range=(0, 1), - input_type=torch.float16, - device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=stage) - - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - ) - # lora_linear.set_mapping(*mapping_info) - - lora_result = lora_linear(torch.cat(inputs))[0] - expected_result = linear(torch.cat(inputs))[0] - - rtol, atol = TOLERANCES[lora_result.dtype] - torch.testing.assert_close(lora_result, - expected_result, - rtol=rtol, - atol=atol) - - -@torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 8]) -@pytest.mark.parametrize("device", ["cuda"]) -@pytest.mark.parametrize("scaling_factors", [(1.0, ), (4.0, ), (4.0, 8.0), - (6.0, 1.0)]) -@pytest.mark.parametrize("max_position", [11, 4096, 32768]) -@pytest.mark.parametrize("is_neox_style", [True, False]) -@pytest.mark.parametrize("rotary_dim", [None, 32]) -@pytest.mark.parametrize("head_size", [32, 108]) -@pytest.mark.parametrize("seq_len", [11, 1024]) -def test_rotary_embedding_long_context(dist_init, num_loras, device, - scaling_factors, max_position, - is_neox_style, rotary_dim, head_size, - seq_len) -> None: - dtype = torch.float16 - seed = 0 - current_platform.seed_everything(seed) - torch.set_default_device(device) - punica_wrapper = PunicaWrapper(8192, 256, device) - max_loras = 8 - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - long_lora_scaling_factors=scaling_factors, - lora_dtype=dtype) - - if rotary_dim is None: - 
rotary_dim = head_size - base = 10000 - batch_size = 5 * num_loras - num_heads = 7 - - # Verify lora is equivalent to linear scaling rotary embedding. - rope = get_rope( - head_size, - rotary_dim, - max_position, - base, - is_neox_style, - ) - lora_rope = LinearScalingRotaryEmbeddingWithLora(rope) - lora_rope.set_mapping(punica_wrapper) - lora_rope.create_lora_weights(max_loras, lora_config) - linear_rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, { - "rope_type": "linear", - "factor": scaling_factors - }) - linear_rope = linear_rope.to(dtype=dtype) - id_to_index = get_random_id_to_index(num_loras, max_loras) - _, index_mapping, prompt_mapping = create_random_inputs( - active_lora_ids=[0], - num_inputs=batch_size, - input_size=(1, max_position), - input_range=(0, lora_config.lora_extra_vocab_size), - input_type=torch.float16, - device=device) - - lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - long_lora_context = LongContextLoRAContext(list(scaling_factors), - rotary_dim) - - next_expected_offset = 0 - # Make sure the offset is correct. - scaling_factor_to_offset = lora_rope.scaling_factor_to_offset - for scaling_factor, offset in scaling_factor_to_offset.items(): - assert offset == next_expected_offset - next_expected_offset += scaling_factor * max_position - - for i in range(len(scaling_factors)): - long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get( - scaling_factors[i], 0) - punica_wrapper.update_metadata( - lora_mapping, - id_to_index, - max_loras, - 512, - lora_config.lora_extra_vocab_size, - long_lora_context=long_lora_context, - ) - # lora_rope.set_mapping(*mapping_info) - - positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) - key = torch.randn_like(query) - ref_q, ref_k = linear_rope(positions, query, key) - actual_q, actual_k = lora_rope(positions, query, key) - - torch.allclose(ref_q, actual_q) - torch.allclose(ref_k, actual_k) - - -@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) -@pytest.mark.parametrize("seed", list(range(256))) -def test_vocab_parallel_embedding_indices(tp_size, seed): - random.seed(seed) - vocab_size = random.randint(4000, 64000) - added_vocab_size = random.randint(0, 1024) - org_vocab_size = vocab_size - added_vocab_size - last_org_vocab_end_index = 0 - last_added_vocab_end_index = org_vocab_size - computed_vocab_size = 0 - computed_org_vocab_size = 0 - computed_added_vocab_size = 0 - vocab_size_padded = -1 - - all_org_tokens: List[int] = [] - all_added_tokens: List[int] = [] - token_ids: List[int] = [] - - for tp_rank in range(tp_size): - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", - return_value=tp_rank - ), patch( - "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", - return_value=tp_size): - vocab_embedding = VocabParallelEmbedding( - vocab_size, 1, org_num_embeddings=org_vocab_size) - vocab_size_padded = vocab_embedding.num_embeddings_padded - shard_indices = vocab_embedding.shard_indices - # Assert that the ranges are contiguous - assert shard_indices.org_vocab_start_index == last_org_vocab_end_index - assert (shard_indices.added_vocab_start_index == - last_added_vocab_end_index) - - # Ensure that we are not exceeding the vocab size - computed_vocab_size += shard_indices.num_elements_padded - computed_org_vocab_size += shard_indices.num_org_elements - computed_added_vocab_size += 
shard_indices.num_added_elements - - # Ensure that the ranges are not overlapping - all_org_tokens.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - all_added_tokens.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - - token_ids.extend( - range(shard_indices.org_vocab_start_index, - shard_indices.org_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_org_elements_padded - - shard_indices.num_org_elements)) - token_ids.extend( - range(shard_indices.added_vocab_start_index, - shard_indices.added_vocab_end_index)) - token_ids.extend([-1] * (shard_indices.num_added_elements_padded - - shard_indices.num_added_elements)) - - last_org_vocab_end_index = shard_indices.org_vocab_end_index - last_added_vocab_end_index = shard_indices.added_vocab_end_index - - assert computed_vocab_size == vocab_size_padded - assert computed_org_vocab_size == org_vocab_size - assert computed_added_vocab_size == added_vocab_size - - # Ensure that the ranges are not overlapping - assert len(all_org_tokens) == len(set(all_org_tokens)) - assert len(all_added_tokens) == len(set(all_added_tokens)) - assert not set(all_org_tokens).intersection(set(all_added_tokens)) - - token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) - reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() - assert reindex_mapping is not None or tp_size == 1 - if reindex_mapping is not None: - reindexed_token_ids = token_ids_tensor[reindex_mapping] - expected = torch.tensor(list(range(0, vocab_size))) - assert reindexed_token_ids[:vocab_size].equal(expected) - assert torch.all(reindexed_token_ids[vocab_size:] == -1) - - -def test_get_masked_input_and_mask(): - x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) - - # base tp 1 case, no padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(x, modified_x) - - # tp 2 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) - - # tp 4 case, no padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=0) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=0) - modified_x_rank_2, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=0) - modified_x_rank_3, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=6, - org_vocab_end_index=8, - added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=0) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) - assert 
torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2])) - - # base tp 1 case, with padding - modified_x, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=8, - added_vocab_start_index=8, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x, - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) - - # tp 2 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=4, - added_vocab_start_index=8, - added_vocab_end_index=10, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=8, - added_vocab_start_index=10, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) - - # tp 4 case, with padding - modified_x_rank_0, _ = get_masked_input_and_mask(x, - org_vocab_start_index=0, - org_vocab_end_index=2, - added_vocab_start_index=8, - added_vocab_end_index=9, - num_org_vocab_padding=2) - modified_x_rank_1, _ = get_masked_input_and_mask(x, - org_vocab_start_index=2, - org_vocab_end_index=4, - added_vocab_start_index=9, - added_vocab_end_index=10, - num_org_vocab_padding=2) - modified_x_rank_2, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=4, - org_vocab_end_index=6, - added_vocab_start_index=10, - added_vocab_end_index=11, - num_org_vocab_padding=2) - modified_x_rank_3, _ = get_masked_input_and_mask( - x, - org_vocab_start_index=6, - org_vocab_end_index=8, - added_vocab_start_index=11, - added_vocab_end_index=12, - num_org_vocab_padding=2) - assert torch.equal(modified_x_rank_0, - torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) - assert torch.equal(modified_x_rank_1, - torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) - assert torch.equal(modified_x_rank_2, - torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) - assert torch.equal(modified_x_rank_3, - torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) From f31b7d1c4636767fab8219fb651bfaba0f5300e8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 31 Jan 2025 17:58:01 +0000 Subject: [PATCH 027/186] Linting Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 +- vllm/lora/ops/torch_ops/lora_ops.py | 9 +- vllm/lora/ops/xla_ops/__init__.py | 4 +- vllm/lora/ops/xla_ops/lora_ops.py | 3 + vllm/lora/punica_wrapper/punica_tpu.py | 22 ++-- vllm/platforms/tpu.py | 2 +- vllm/worker/tpu_model_runner.py | 137 +++++++++++++------------ vllm/worker/tpu_worker.py | 9 +- 8 files changed, 99 insertions(+), 91 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 16c5771a5123..1a1ffda03f67 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1092,7 +1092,7 @@ def _get_logits( lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - + if current_platform.is_tpu(): # Because nan_to_num_ doesn't work with actual -inf values on TPU neg_inf = torch.finfo(lora_logits.dtype).min @@ -1100,7 +1100,7 @@ def _get_logits( else: neg_inf = float("-inf") pos_inf = float("inf") - + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * 
lora_logits.shape[1], lora_logits.shape[2], diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index 30240c5e0bc9..1a43f22215e2 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -32,7 +32,8 @@ def bgmv_expand(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -75,7 +76,8 @@ def bgmv_shrink(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -112,7 +114,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape((batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 4785af8520d3..632a5d0274b0 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,7 +1,7 @@ from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) __all__ = [ "bgmv_expand", diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index d6c630880644..a52ac51b43c9 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,6 +1,8 @@ import torch + from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink + def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -17,6 +19,7 @@ def sgmv_expand(inputs: torch.Tensor, bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs) + def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b6739bd97bdb..b831b4878b02 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -2,9 +2,8 @@ import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase @@ -139,17 +138,18 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], 
torch.Tensor], """ x = x.view(-1, x.shape[-1]) - - shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) - + + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - + y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - + shrink_fun(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) @@ -181,8 +181,10 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ - expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode) - + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + y_org = y y = y.view(-1, y.shape[-1]) offset_left = offset_start diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0ef558850c0d..f57a30cd557a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -61,7 +61,7 @@ def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: def is_pin_memory_available(cls): logger.warning("Pin memory is not supported on TPU.") return False - + @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 248bcc4627c1..bb973f883248 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -3,8 +3,8 @@ import enum import time from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Set, - Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, Union) from unittest.mock import patch import numpy as np @@ -45,6 +45,7 @@ # This can significantly affect the performance if too large. _MAX_NUM_SAMPLES = 128 + class ExecutionMode(enum.Enum): PREFILL = enum.auto() DECODE = enum.auto() @@ -53,6 +54,7 @@ class ExecutionMode(enum.Enum): def is_prefill(self) -> bool: return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -126,7 +128,7 @@ def __init__( False, ) self.cached_step_outputs: List[torch.Tensor] = [] - + # LoRA support self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None @@ -162,14 +164,14 @@ def load_model(self) -> None: model = model.eval() xm.wait_device_ops() self.model = model - + if self.lora_config: assert supports_lora( self.model ), f"{self.model.__class__.__name__} does not support LoRA yet." 
max_pos_embeddings = self.model.config.max_position_embeddings - + self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, @@ -181,12 +183,12 @@ def load_model(self) -> None: max_position_embeddings=max_pos_embeddings, ) self.model = self.lora_manager.create_lora_manager(self.model) - + self.model = ModelWrapper(self.model) self.model = torch.compile(self.model, - backend="openxla", - fullgraph=True, - dynamic=False) + backend="openxla", + fullgraph=True, + dynamic=False) def get_model(self) -> nn.Module: return self.model.model @@ -278,12 +280,13 @@ def _dummy_run( t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 - - # Create a series of dummy loras and requests for them. Make to fill all lora slots. + + # Create a series of dummy loras and requests for them. + # Make to fill all lora slots. if self.lora_config: dummy_lora_requests: Set[LoRARequest] = set() dummy_lora_mapping: LoRAMapping - + assert self.lora_manager is not None with self.lora_manager.dummy_lora_cache(): for lora_id in range(1, self.lora_config.max_loras + 1): @@ -292,12 +295,13 @@ def _dummy_run( lora_int_id=lora_id, lora_path="/not/a/real/path", ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=self.lora_config.max_lora_rank) + self.lora_manager.add_dummy_lora( + dummy_lora_request, + rank=self.lora_config.max_lora_rank) dummy_lora_requests.add(dummy_lora_request) dummy_lora_mapping = LoRAMapping( - [lora_id] * batch_size * seq_len, [lora_id] * batch_size, is_prefill=exec_mode.is_prefill() - ) + [lora_id] * batch_size * seq_len, [lora_id] * batch_size, + is_prefill=exec_mode.is_prefill()) self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) # NOTE(woosuk): There are two stages of compilation: torch.compile and @@ -637,55 +641,52 @@ def prepare_model_input( list(metadata.seq_data.keys()) for metadata in seq_group_metadata_list ] - + lora_inputs = [] if self.load_config is not None: - lora_inputs = self._prepare_lora_input(seq_group_metadata_list, is_prompt, padded_batch_size) - - return ModelInputForTPU( - token_ids=input_tokens, - position_ids=input_positions, - attn_metadata=attn_metadata, - input_lens=input_lens, - t=t, - p=p, - num_samples=num_samples, - n=n, - seq_groups=seq_groups, - lora_inputs=lora_inputs - ) - + lora_inputs = self._prepare_lora_input(seq_group_metadata_list, + is_prompt, + padded_batch_size) + + return ModelInputForTPU(token_ids=input_tokens, + position_ids=input_positions, + attn_metadata=attn_metadata, + input_lens=input_lens, + t=t, + p=p, + num_samples=num_samples, + n=n, + seq_groups=seq_groups, + lora_inputs=lora_inputs) + def _prepare_lora_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], - is_prefill: bool, - padded_batch_size: int) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: + is_prefill: bool, padded_batch_size: int + ) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: """ - Prepares a list of LoRA inputs. If we're decoding then the list will only have 1 item, - otherwise there'll be an item for each sequence + Prepares a list of LoRA inputs. 
If we're decoding then the list will + only have 1 item, otherwise there'll be an item for each sequence """ - + lora_input = [] if is_prefill: for seq in seq_group_metadata_list: lora_id = seq.lora_int_id query_len = seq.token_chunk_size padded_query_len = _get_padded_prefill_len(query_len) - + index_mapping = [lora_id] * padded_query_len prompt_mapping = [lora_id] - + lora_request = set() if seq.lora_request is not None: lora_request.add(seq.lora_request) - - lora_input.append(( - lora_request, - LoRAMapping( - index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=True - ) - )) + + lora_input.append( + (lora_request, + LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=True))) else: lora_request = set() index_mapping = [] @@ -695,22 +696,21 @@ def _prepare_lora_input( index_mapping += [lora_id] prompt_mapping += [lora_id] - + if seq.lora_request is not None: lora_request.add(seq.lora_request) - - index_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) - prompt_mapping += [0] * (padded_batch_size - len(seq_group_metadata_list)) - - lora_input.append(( - lora_request, - LoRAMapping( - index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=False - ) - )) - + + index_mapping += [0] * (padded_batch_size - + len(seq_group_metadata_list)) + prompt_mapping += [0] * (padded_batch_size - + len(seq_group_metadata_list)) + + lora_input.append( + (lora_request, + LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=False))) + return lora_input def make_model_input_from_broadcasted_tensor_dict( @@ -728,7 +728,7 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None - + if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -804,12 +804,12 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - + if self.lora_config is not None: assert len(model_input.lora_inputs) == batch_size lora_requests, lora_mapping = model_input.lora_inputs[i] self.set_active_loras(lora_requests, lora_mapping) - + with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): @@ -859,12 +859,12 @@ def execute_model( t = model_input.t.to(self.device) p = model_input.p.to(self.device) input_lens = model_input.input_lens.to(self.device) - + if self.lora_config is not None: assert len(model_input.lora_inputs) == 1 lora_requests, lora_mapping = model_input.lora_inputs[0] self.set_active_loras(lora_requests, lora_mapping) - + for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping with set_forward_context(model_input.attn_metadata, @@ -906,7 +906,7 @@ def execute_model( sampler_output = _make_decode_output(next_token_ids, model_input.seq_groups) return [sampler_output] - + def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -938,6 +938,7 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_adapters() + class ModelWrapper(nn.Module): def __init__(self, model: nn.Module): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index eb04479a10c3..c73327b7c8db 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from 
typing import List, Optional, Tuple, Union, Set +from typing import List, Optional, Set, Tuple, Union import torch import torch_xla.core.xla_model as xm @@ -18,8 +18,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) @@ -289,10 +288,10 @@ def execute_worker(self, worker_input: WorkerInput) -> None: if src_indices.numel() > 0: attn_backend.copy_blocks(self.tpu_cache, (src_indices, dst_indices)) - + def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_runner.add_lora(lora_request) - + def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) From 87ff73e8a70f23419cd6d5b063e35a639e437fab Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Feb 2025 17:51:52 +0000 Subject: [PATCH 028/186] Fixed import error Signed-off-by: Akshat Tripathi --- vllm/lora/ops/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/lora/ops/__init__.py diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 96c3ddeb111d8a8c05247b351061cd11a61aaceb Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Feb 2025 14:15:54 +0000 Subject: [PATCH 029/186] lint Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 2 ++ vllm/lora/ops/xla_ops/lora_ops.py | 3 +++ vllm/lora/punica_wrapper/punica_tpu.py | 4 +++- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 632a5d0274b0..67ffde460755 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index a52ac51b43c9..b664b93fbf6f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + import torch from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink @@ -35,6 +37,7 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) + print("SGMV", lora_indices_tensor, lora_a_weights) bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index b831b4878b02..84245e82eb8a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,3 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + from typing import Callable, Optional, Tuple, Union import torch @@ -222,7 +224,7 @@ def add_lora_embedding(self, add_inputs (bool): Default to True. 
""" - # Embedding layer only need expand op + # Embedding layer only needs the expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) expand_fun(y, x, lora_b_stacked, add_inputs) From 4e72edea6da3aa4713facef1ec928de096295bd6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:04:32 +0000 Subject: [PATCH 030/186] Abstracted out infinity values Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 14 +++++--------- vllm/platforms/interface.py | 7 +++++++ vllm/platforms/tpu.py | 6 +++++- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 1a1ffda03f67..0efefa3585c2 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1089,18 +1089,14 @@ def _get_logits( torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - lora_logits[-1] = float("-inf") + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded - if current_platform.is_tpu(): - # Because nan_to_num_ doesn't work with actual -inf values on TPU - neg_inf = torch.finfo(lora_logits.dtype).min - pos_inf = torch.finfo(lora_logits.dtype).max - else: - neg_inf = float("-inf") - pos_inf = float("inf") - lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index e7e55e11775c..30a27fea0872 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -323,6 +323,13 @@ def get_punica_wrapper(cls) -> str: """ raise NotImplementedError + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + """ + Return the platform specific values for (-inf, inf) + """ + return float("-inf"), float("inf") + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index f57a30cd557a..6dde25a6f065 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple import torch @@ -66,6 +66,10 @@ def is_pin_memory_available(cls): def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + @classmethod def inference_mode(cls): return torch.no_grad() From e4d35cee08cf85bd6dfa6373bfafc9fdb7f4ad05 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:48:08 +0000 Subject: [PATCH 031/186] Moved and modified bgmv ops from the cpu backend to the tpu backend, because xla doesn't allow partial updates Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 90 ++++++++++++++++++++++++-- vllm/lora/punica_wrapper/punica_tpu.py | 9 ++- 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index b664b93fbf6f..308d361fe7eb 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -2,8 +2,6 @@ import torch -from ..torch_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink - def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -22,6 +20,37 @@ def sgmv_expand(inputs: 
torch.Tensor, add_inputs) +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + outputs = torch.cat( + (outputs, + torch.zeros((batch_size, output_tensor.shape[1] - outputs.shape[1]), + device=outputs.device)), + dim=1) + + if add_inputs: + output_tensor += outputs[:limit, :] + else: + output_tensor = outputs[:limit, :] + + def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -37,11 +66,28 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - print("SGMV", lora_indices_tensor, lora_a_weights) bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling) +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + batch_size, output_size, input_size = selected_loras.shape + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + output_tensor = scaling * outputs[:] + + def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -53,9 +99,45 @@ def sgmv_expand_slice(inputs: torch.Tensor, token_nums: int, slice_offset: int, slice_size: int, + total_size: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) + slice_offset, slice_size, total_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + total_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + + inputs = inputs.to(dtype=output_tensor.dtype) + + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + + batch_size, output_size, input_size = selected_loras.shape + + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + outputs = torch.cat(( + torch.zeros((batch_size, slice_offset), device=outputs.device), + outputs, + torch.zeros((batch_size, total_size - (slice_offset + slice_size)), + device=outputs.device), + ), + dim=1) + + if add_inputs: + output_tensor += outputs + else: + output_tensor = outputs diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 
84245e82eb8a..920aacfbf8e8 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -89,6 +89,7 @@ def _expand_slice_prefill( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ): #No LoRA request, so return directly @@ -101,6 +102,7 @@ def _expand_slice_prefill( *self.prefill_metadata, y_offset, y_slice_size, + y_total_size, add_inputs, ) @@ -111,12 +113,13 @@ def _expand_slice_decode( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ): if self.no_lora: return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + y_slice_size, y_total_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -161,7 +164,6 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], - offset_start: int = 0, add_inputs=True, **kwargs) -> None: """ @@ -189,7 +191,7 @@ def add_expand(self, y_org = y y = y.view(-1, y.shape[-1]) - offset_left = offset_start + offset_left = 0 if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) @@ -200,6 +202,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], + y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From 3cf06807d94b9ecb01de44af53d856f95642de74 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:55:49 +0000 Subject: [PATCH 032/186] Removed total_size for linting Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 11 +++++------ vllm/lora/punica_wrapper/punica_tpu.py | 6 +----- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 308d361fe7eb..e494a2fed52d 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -85,7 +85,7 @@ def bgmv_shrink(inputs: torch.Tensor, outputs = (selected_loras @ inputs.reshape( (batch_size, input_size, 1))).reshape((batch_size, output_size)) - output_tensor = scaling * outputs[:] + output_tensor = scaling * outputs def sgmv_expand_slice(inputs: torch.Tensor, @@ -99,13 +99,12 @@ def sgmv_expand_slice(inputs: torch.Tensor, token_nums: int, slice_offset: int, slice_size: int, - total_size: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, total_size, add_inputs) + slice_offset, slice_size, add_inputs) def bgmv_expand_slice(inputs: torch.Tensor, @@ -114,7 +113,6 @@ def bgmv_expand_slice(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, slice_offset: int, slice_size: int, - total_size: int, add_inputs: bool = True): selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) @@ -132,8 +130,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), outputs, - torch.zeros((batch_size, total_size - (slice_offset + slice_size)), - device=outputs.device), + torch.zeros( + (batch_size, output_tensor.shape[1] - (slice_offset + slice_size)), + device=outputs.device), ), dim=1) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py 
b/vllm/lora/punica_wrapper/punica_tpu.py index 920aacfbf8e8..4b5642033ff7 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -89,7 +89,6 @@ def _expand_slice_prefill( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool, ): #No LoRA request, so return directly @@ -102,7 +101,6 @@ def _expand_slice_prefill( *self.prefill_metadata, y_offset, y_slice_size, - y_total_size, add_inputs, ) @@ -113,13 +111,12 @@ def _expand_slice_decode( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool, ): if self.no_lora: return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, y_total_size, add_inputs) + y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -202,7 +199,6 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], - y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From a8ab0c977687fa8ed53e6edf4ce364eea659430b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 19:04:58 +0000 Subject: [PATCH 033/186] Reverted changes to torch_ops Signed-off-by: Akshat Tripathi --- vllm/lora/ops/torch_ops/lora_ops.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/lora/ops/torch_ops/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py index 1a43f22215e2..af79f98415cb 100644 --- a/vllm/lora/ops/torch_ops/lora_ops.py +++ b/vllm/lora/ops/torch_ops/lora_ops.py @@ -30,10 +30,7 @@ def bgmv_expand(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -74,10 +71,7 @@ def bgmv_shrink(inputs: torch.Tensor, if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) output_tensor[:, :outputs.shape[1]] = scaling * outputs[:] @@ -113,9 +107,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(dim=1) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) if add_inputs: output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] From d73f1cea8c18fce976b0985e82669ca6bfd1c320 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 19:11:20 +0000 Subject: [PATCH 034/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 1 + 1 
file changed, 1 insertion(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 4b5642033ff7..3b7a6dad035d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -161,6 +161,7 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], + offset_start: int = 0, add_inputs=True, **kwargs) -> None: """ From e01d9a4cfc6051ca403a1fef7a25039a21f2946f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 3 Mar 2025 18:04:49 +0000 Subject: [PATCH 035/186] Replaced in-place buffer updates with direct returning Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 32 +++++++---- vllm/lora/ops/xla_ops/lora_ops.py | 25 +++++---- vllm/lora/punica_wrapper/punica_base.py | 10 ++-- vllm/lora/punica_wrapper/punica_tpu.py | 74 +++++++++++++------------ vllm/platforms/interface.py | 5 ++ vllm/platforms/tpu.py | 4 ++ 6 files changed, 88 insertions(+), 62 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 0efefa3585c2..e527addc99f9 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -258,10 +258,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - self.punica_wrapper.add_lora_embedding(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + + lora_output = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + if not current_platform.can_update_inplace(): + full_output = lora_output + return full_output.view_as(full_output_org) @classmethod @@ -395,10 +400,12 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + lora_output = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + return output @classmethod @@ -1113,9 +1120,12 @@ def _get_logits( lora_logits.shape[1]] = lora_logits # LogitsProcessorWithLoRA always using bgmv - self.punica_wrapper.add_lora_logits(logits, hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, + 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output # Remove paddings in vocab (if any). 
logits = logits[:, :self.base_layer.vocab_size] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index e494a2fed52d..7ac7d16fbf88 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -16,8 +16,8 @@ def sgmv_expand(inputs: torch.Tensor, exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) + return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) def bgmv_expand(inputs: torch.Tensor, @@ -46,9 +46,9 @@ def bgmv_expand(inputs: torch.Tensor, dim=1) if add_inputs: - output_tensor += outputs[:limit, :] + return output_tensor + outputs[:limit, :] else: - output_tensor = outputs[:limit, :] + return outputs[:limit, :] def sgmv_shrink( @@ -66,8 +66,8 @@ def sgmv_shrink( exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) + return bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) def bgmv_shrink(inputs: torch.Tensor, @@ -75,6 +75,7 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) if len(selected_loras.shape) == 4: @@ -85,7 +86,7 @@ def bgmv_shrink(inputs: torch.Tensor, outputs = (selected_loras @ inputs.reshape( (batch_size, input_size, 1))).reshape((batch_size, output_size)) - output_tensor = scaling * outputs + return scaling * outputs def sgmv_expand_slice(inputs: torch.Tensor, @@ -103,8 +104,9 @@ def sgmv_expand_slice(inputs: torch.Tensor, exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) - bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, - slice_offset, slice_size, add_inputs) + return bgmv_expand_slice(inputs, lora_b_weights, output_tensor, + exploded_indices, slice_offset, slice_size, + add_inputs) def bgmv_expand_slice(inputs: torch.Tensor, @@ -114,6 +116,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( dtype=output_tensor.dtype) @@ -137,6 +140,6 @@ def bgmv_expand_slice(inputs: torch.Tensor, dim=1) if add_inputs: - output_tensor += outputs + return output_tensor + outputs else: - output_tensor = outputs + return outputs diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 94fa3f27ab60..0332867055b7 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -48,7 +48,7 @@ def add_shrink( lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. """ @@ -66,7 +66,7 @@ def add_expand( offset_start: int = 0, add_inputs=True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. """ @@ -80,7 +80,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. 
@@ -98,7 +98,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. """ @@ -114,7 +114,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 3b7a6dad035d..602ec824853b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -34,7 +34,7 @@ def _shrink_prefill( #No LoRA request, so return directly if self.no_lora: return - sgmv_shrink( + return sgmv_shrink( x, w_t_all, y, @@ -51,7 +51,7 @@ def _shrink_decode( ): if self.no_lora: return - bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( self, @@ -63,7 +63,7 @@ def _expand_prefill( #No LoRA request, so return directly if self.no_lora: return - sgmv_expand( + return sgmv_expand( x, w_t_all, y, @@ -80,7 +80,7 @@ def _expand_decode( ): if self.no_lora: return - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( self, @@ -90,11 +90,11 @@ def _expand_slice_prefill( y_offset: int, y_slice_size: int, add_inputs: bool, - ): + ) -> torch.Tensor: #No LoRA request, so return directly if self.no_lora: return - sgmv_expand_slice( + return sgmv_expand_slice( x, w_t_all, y, @@ -112,15 +112,15 @@ def _expand_slice_decode( y_offset: int, y_slice_size: int, add_inputs: bool, - ): + ) -> torch.Tensor: if self.no_lora: return - bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_inputs) + return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, + y_offset, y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs): + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. When `is_prefill is` true, it indicates that it is currently the @@ -144,6 +144,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], shrink_fun: Callable = (self._shrink_prefill if self.is_prefill else self._shrink_decode) + new_y = [] # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] @@ -152,8 +153,10 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - shrink_fun(y_s, x, lora_s, scale) + y_s = shrink_fun(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) + new_y.append(y_s) + return tuple(new_y) def add_expand(self, y: torch.Tensor, @@ -163,7 +166,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Performs GEMM and bias addition for multiple slices of lora_b. 
@@ -191,10 +194,10 @@ def add_expand(self, y = y.view(-1, y.shape[-1]) offset_left = 0 if lora_bias_stacked is not None: - self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - expand_slice_fun( + y = expand_slice_fun( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -203,14 +206,14 @@ def add_expand(self, add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] - y = y.view_as(y_org) + return y.view_as(y_org) def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -227,7 +230,7 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_inputs) + return expand_fun(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -239,7 +242,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applicable to linear-related lora. @@ -279,14 +282,14 @@ def add_lora_linear(self, dtype=torch.float32, device=x.device, ) - self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - self.add_expand(y, - buffer, - lora_b_stacked, - None, - output_slices, - add_inputs=True, - **kwargs) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) def add_lora_logits(self, y: torch.Tensor, @@ -296,7 +299,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> torch.Tensor: """ Applies lora specifically for LogitsProcessorWithLoRA. @@ -323,10 +326,11 @@ def add_lora_logits(self, dtype=torch.float32, device=x.device) # LogitsProcessorWithLoRA always using bgmv. 
- bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, - lora_b_stacked, - y, - self.sampler_indices, - add_inputs=True) - y = y.view_as(y_org) + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, + scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + return y.view_as(y_org) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 30a27fea0872..3477b1b3fa01 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -330,6 +330,11 @@ def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: """ return float("-inf"), float("inf") + @classmethod + def can_update_inplace(cls) -> bool: + """Checks if the platform allows inplace memory updates""" + return True + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 6dde25a6f065..0c9d247d4a5d 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -70,6 +70,10 @@ def get_punica_wrapper(cls) -> str: def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: return torch.finfo(dtype).min, torch.finfo(dtype).max + @classmethod + def can_update_inplace(cls): + return False + @classmethod def inference_mode(cls): return torch.no_grad() From 0c1bfb94febc281cf033d16cdb0b353cc11123fb Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Feb 2025 14:51:29 +0000 Subject: [PATCH 036/186] PunicaWrapperTPU now returns unchanged buffer if no loras are needed Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 602ec824853b..90058cd404d0 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -33,7 +33,7 @@ def _shrink_prefill( ): #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_shrink( x, w_t_all, @@ -50,7 +50,7 @@ def _shrink_decode( scale: float, ): if self.no_lora: - return + return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def _expand_prefill( @@ -62,7 +62,7 @@ def _expand_prefill( ): #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_expand( x, w_t_all, @@ -79,7 +79,7 @@ def _expand_decode( add_inputs: bool, ): if self.no_lora: - return + return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( @@ -93,7 +93,7 @@ def _expand_slice_prefill( ) -> torch.Tensor: #No LoRA request, so return directly if self.no_lora: - return + return y return sgmv_expand_slice( x, w_t_all, @@ -114,7 +114,7 @@ def _expand_slice_decode( add_inputs: bool, ) -> torch.Tensor: if self.no_lora: - return + return y return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) From 46ce7fa3055e2e9abf1ee060aead11842c1ea176 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:46:03 +0000 Subject: [PATCH 037/186] Simplified TPU prefill Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 15 ---------- vllm/lora/punica_wrapper/punica_tpu.py | 39 +++++++++++++++++++------- vllm/worker/tpu_model_runner.py | 1 + 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 7ac7d16fbf88..69449981b89a 100644 --- 
a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -6,12 +6,7 @@ def sgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, add_inputs: bool = False): exploded_indices = torch.repeat_interleave(lora_indices_tensor, inputs.size(0)) @@ -55,12 +50,7 @@ def sgmv_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, scaling: float, ): exploded_indices = torch.repeat_interleave(lora_indices_tensor, @@ -92,12 +82,7 @@ def bgmv_shrink(inputs: torch.Tensor, def sgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False): diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 90058cd404d0..847cdd75a76c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -141,8 +141,8 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) + # shrink_fun: Callable = (self._shrink_prefill + # if self.is_prefill else self._shrink_decode) new_y = [] # TODO fuse these kernels @@ -153,7 +153,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - y_s = shrink_fun(y_s, x, lora_s, scale) + y_s = self._shrink_decode(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) @@ -186,9 +186,9 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. 
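            Roughly, for each token i and slice s placed at column offset o_s:
                y[i, o_s : o_s + output_slices[s]] += x[s][i] @ lora_b_stacked[s][lora_idx[i]].T
            (clarifying note; lora_idx[i] here stands for the token's LoRA index)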
""" - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) + # expand_slice_fun: Callable = (self._expand_slice_prefill + # if self.is_prefill else + # self._expand_slice_decode) y_org = y y = y.view(-1, y.shape[-1]) @@ -197,7 +197,7 @@ def add_expand(self, y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = expand_slice_fun( + y = self._expand_slice_decode( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -228,9 +228,9 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - return expand_fun(y, x, lora_b_stacked, add_inputs) + # expand_fun: Callable = (self._expand_prefill + # if self.is_prefill else self._expand_decode) + return self._expand_decode(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -334,3 +334,22 @@ def add_lora_logits(self, self.sampler_indices, add_inputs=True) return y.view_as(y_org) + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) + + def set_no_lora(self, no_lora: bool): + self.no_lora = no_lora + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + """ + return (self._lora_indices_per_batch[:self.batch_size],) \ No newline at end of file diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index bb973f883248..43728c725d07 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,6 +917,7 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 5d0cc375a49c46d9783581b45931b507b1d06c66 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:48:13 +0000 Subject: [PATCH 038/186] Removed sgmv kernels from TPU implementation Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 44 ------------ vllm/lora/punica_wrapper/punica_tpu.py | 95 +++----------------------- 2 files changed, 8 insertions(+), 131 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 69449981b89a..483bef186185 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -2,19 +2,6 @@ import torch - -def sgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, - add_inputs) - - def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -45,21 +32,6 @@ def bgmv_expand(inputs: torch.Tensor, else: return outputs[:limit, :] - -def sgmv_shrink( - 
inputs: torch.Tensor, - lora_a_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float, -): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, - scaling) - - def bgmv_shrink(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, @@ -78,22 +50,6 @@ def bgmv_shrink(inputs: torch.Tensor, return scaling * outputs - -def sgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = False): - exploded_indices = torch.repeat_interleave(lora_indices_tensor, - inputs.size(0)) - - return bgmv_expand_slice(inputs, lora_b_weights, output_tensor, - exploded_indices, slice_offset, slice_size, - add_inputs) - - def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 847cdd75a76c..1b8e8ed30e5b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -4,8 +4,7 @@ import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) from .punica_base import PunicaWrapperBase @@ -24,25 +23,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - def _shrink_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_shrink( - x, - w_t_all, - y, - *self.prefill_metadata, - scale, - ) - - def _shrink_decode( + def shrink( self, y: torch.Tensor, x: torch.Tensor, @@ -53,25 +34,7 @@ def _shrink_decode( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def _expand_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_inputs: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_inputs, - ) - - def _expand_decode( + def expand( self, y: torch.Tensor, x: torch.Tensor, @@ -82,29 +45,8 @@ def _expand_decode( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool, - ) -> torch.Tensor: - #No LoRA request, so return directly - if self.no_lora: - return y - return sgmv_expand_slice( - x, - w_t_all, - y, - *self.prefill_metadata, - y_offset, - y_slice_size, - add_inputs, - ) - def _expand_slice_decode( + def expand_slice( self, y: torch.Tensor, x: torch.Tensor, @@ -141,9 +83,6 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) - # shrink_fun: Callable = (self._shrink_prefill - # if self.is_prefill else self._shrink_decode) - new_y = [] # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): @@ -153,7 +92,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - y_s = 
self._shrink_decode(y_s, x, lora_s, scale) + y_s = self.shrink(y_s, x, lora_s, scale) y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) @@ -186,10 +125,6 @@ def add_expand(self, output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ - # expand_slice_fun: Callable = (self._expand_slice_prefill - # if self.is_prefill else - # self._expand_slice_decode) - y_org = y y = y.view(-1, y.shape[-1]) offset_left = 0 @@ -197,7 +132,7 @@ def add_expand(self, y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = self._expand_slice_decode( + y = self.expand_slice( y, x[slice_idx], lora_b_stacked[slice_idx], @@ -228,9 +163,7 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - # expand_fun: Callable = (self._expand_prefill - # if self.is_prefill else self._expand_decode) - return self._expand_decode(y, x, lora_b_stacked, add_inputs) + return self.expand(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -340,16 +273,4 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora - - @property - def prefill_metadata( - self - ) -> Tuple[torch.Tensor, int, int]: - """ - This property provides a convenient way to access the necessary - metadata for prefill-related kernel computations. - 1. lora_indices_per_batch: Tensor of lora indices, and an index of - -1 means no lora should be applied. - """ - return (self._lora_indices_per_batch[:self.batch_size],) \ No newline at end of file + self.no_lora = no_lora \ No newline at end of file From 7590b0e37dfb2837eef1073f4f40daf8de4e707c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:51:19 +0000 Subject: [PATCH 039/186] Fix bug Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 10 ++-------- vllm/worker/tpu_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 67ffde460755..04c399954d14 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,15 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand # noqa: F401 -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, - sgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) __all__ = [ "bgmv_expand", "bgmv_expand_slice", - "bgmv_shrink", - "sgmv_expand", - "sgmv_expand_slice", - "sgmv_shrink", + "bgmv_shrink" ] diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 43728c725d07..ee388719bde4 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,7 +917,7 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From e7f75b5deb8086b66b1c1bd8a29c8a5f1a7c7ecd Mon Sep 
17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Feb 2025 17:55:00 +0000 Subject: [PATCH 040/186] Added torch.compiles to PunicaWrapperTPU functions Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1b8e8ed30e5b..f29ac59c5c4b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,7 +22,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - + @torch.compile(backend="openxla") def shrink( self, y: torch.Tensor, @@ -34,6 +34,7 @@ def shrink( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + @torch.compile(backend="openxla") def expand( self, y: torch.Tensor, @@ -45,7 +46,7 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - + @torch.compile(backend="openxla") def expand_slice( self, y: torch.Tensor, @@ -60,6 +61,7 @@ def expand_slice( return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) + @torch.compile(backend="openxla") def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: @@ -97,6 +99,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], new_y.append(y_s) return tuple(new_y) + @torch.compile(backend="openxla") def add_expand(self, y: torch.Tensor, x: Union[Tuple[torch.Tensor, ...], torch.Tensor], @@ -143,6 +146,7 @@ def add_expand(self, offset_left += output_slices[slice_idx] return y.view_as(y_org) + @torch.compile(backend="openxla") def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, @@ -165,6 +169,7 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) + @torch.compile(backend="openxla") def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, @@ -223,7 +228,8 @@ def add_lora_linear(self, output_slices, add_inputs=True, **kwargs) - + + @torch.compile(backend="openxla") def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, From fe193f7bf3ac539b8ad269403795d89e72a1fdf1 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:02:58 +0000 Subject: [PATCH 041/186] Replaced "x[x==-1] = y" with "x = torch.where(x == - 1, y)" Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index dbc2d27c597f..00c3689ef462 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,11 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 + embeddings_indices = torch.where(embeddings_indices == -1, embeddings_indices, max_loras - 1) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.where(sampler_indices_padded == -1, sampler_indices_padded, max_loras - 1) sampler_indices_padded = 
torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) From 52e3911b876a7d9ea852aa12cf62b4d61636659d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:06:34 +0000 Subject: [PATCH 042/186] Revert "Added torch.compiles to PunicaWrapperTPU functions" This reverts commit b78b08898dddcb592480d4179e8d346f78eaabd5. Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index f29ac59c5c4b..1b8e8ed30e5b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,7 +22,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - @torch.compile(backend="openxla") + def shrink( self, y: torch.Tensor, @@ -34,7 +34,6 @@ def shrink( return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - @torch.compile(backend="openxla") def expand( self, y: torch.Tensor, @@ -46,7 +45,7 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - @torch.compile(backend="openxla") + def expand_slice( self, y: torch.Tensor, @@ -61,7 +60,6 @@ def expand_slice( return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) - @torch.compile(backend="openxla") def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: @@ -99,7 +97,6 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], new_y.append(y_s) return tuple(new_y) - @torch.compile(backend="openxla") def add_expand(self, y: torch.Tensor, x: Union[Tuple[torch.Tensor, ...], torch.Tensor], @@ -146,7 +143,6 @@ def add_expand(self, offset_left += output_slices[slice_idx] return y.view_as(y_org) - @torch.compile(backend="openxla") def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, @@ -169,7 +165,6 @@ def add_lora_embedding(self, # Embedding layer only needs the expand op return self.expand(y, x, lora_b_stacked, add_inputs) - @torch.compile(backend="openxla") def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, @@ -228,8 +223,7 @@ def add_lora_linear(self, output_slices, add_inputs=True, **kwargs) - - @torch.compile(backend="openxla") + def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, From 33a70b097752bc737eb2008a93745ea4e9549d24 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 11:14:25 +0000 Subject: [PATCH 043/186] Fix linting Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 9 +++------ vllm/lora/punica_wrapper/punica_tpu.py | 12 ++++++------ vllm/worker/tpu_model_runner.py | 3 ++- 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 04c399954d14..94062b05d916 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) -__all__ = [ - "bgmv_expand", - "bgmv_expand_slice", - 
"bgmv_shrink" -] +__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1b8e8ed30e5b..fdbf9cb96ddb 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch -from vllm.lora.ops.xla_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) +from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink from .punica_base import PunicaWrapperBase @@ -45,7 +45,6 @@ def expand( return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - def expand_slice( self, y: torch.Tensor, @@ -270,7 +269,8 @@ def add_lora_logits(self, def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self.batch_size].copy_(token_lora_tensor[:self.batch_size]) - + self._lora_indices_per_batch[:self.batch_size].copy_( + token_lora_tensor[:self.batch_size]) + def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora \ No newline at end of file + self.no_lora = no_lora diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index ee388719bde4..88fe864ec4ae 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,7 +917,8 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora(len(lora_requests) == 0) # TODO: Cleanup + self.lora_manager._adapter_manager.punica_wrapper.set_no_lora( + len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 67446b2d4b875d00616de6d4b51ec0f23f005059 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:21:58 +0000 Subject: [PATCH 044/186] Added lora hotswapping test Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 tests/tpu/test_lora.py diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py new file mode 100644 index 000000000000..ed1553fbb6f2 --- /dev/null +++ b/tests/tpu/test_lora.py @@ -0,0 +1,29 @@ +import vllm +import sys + +from vllm.lora.request import LoRARequest + +def test_lora_hotswapping(): + lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = vllm.LLM( + model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=2, + max_lora_rank=8 + ) + + prompt = "What is 1+1?" 
+ + for _ in range(10): + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), lora_request=req)[0].outputs[0].text + assert output.strip()[0] == i + 1 From 0db19b1e70283aa461931df677275ab288638121 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:45:43 +0000 Subject: [PATCH 045/186] Fixed hotswapping test prompt Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index ed1553fbb6f2..d32adda9fe48 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -1,5 +1,4 @@ import vllm -import sys from vllm.lora.request import LoRARequest @@ -21,7 +20,7 @@ def test_lora_hotswapping(): max_lora_rank=8 ) - prompt = "What is 1+1?" + prompt = "What is 1+1? \n" for _ in range(10): for i, req in enumerate(lora_requests): From a4c3b0a86cebe506074d0d47a529127ea6a650ff Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 13:12:40 +0000 Subject: [PATCH 046/186] Fixed bug in tpu lora test Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index d32adda9fe48..d3d2c7eb2e1d 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -25,4 +25,4 @@ def test_lora_hotswapping(): for _ in range(10): for i, req in enumerate(lora_requests): output = llm.generate(prompt, sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), lora_request=req)[0].outputs[0].text - assert output.strip()[0] == i + 1 + assert int(output.strip()[0]) == i + 1 \ No newline at end of file From 9d6c3881e1430ce2ec6b90dbc2500603a4f23405 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 17:21:36 +0000 Subject: [PATCH 047/186] Merged set_no_lora() functionality with _udpate_prefill_metada Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 5 ++--- vllm/worker/tpu_model_runner.py | 2 -- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index fdbf9cb96ddb..1fcedfb61a93 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -271,6 +271,5 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) - - def set_no_lora(self, no_lora: bool): - self.no_lora = no_lora + # TODO: .item() is extremely inefficient on TPU, so find a way around it + self.no_lora = torch.all(token_lora_tensor == -1).item() diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 88fe864ec4ae..bb973f883248 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -917,8 +917,6 @@ def set_active_loras(self, lora_requests: Set[LoRARequest], if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - self.lora_manager._adapter_manager.punica_wrapper.set_no_lora( - len(lora_requests) == 0) # TODO: Cleanup def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 2a9978ec5084c2eac62404ac90e96f476e43e64e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 14 Feb 2025 17:22:50 +0000 Subject: [PATCH 048/186] Added Multi-LoRA 
functionality to TPU V1 Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 30 ++++++++++++++++++++++-------- vllm/v1/worker/tpu_worker.py | 4 ++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 104e5a3dcfc5..f45c289125ce 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -28,6 +28,7 @@ from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -40,7 +41,7 @@ INVALID_TOKEN_ID = -1 -class TPUModelRunner: +class TPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -410,6 +411,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): + 1].to(self.device) seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to( self.device) + + if self.lora_config is not None: + self.set_active_loras(self.input_batch, num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -529,6 +533,12 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + model = self.load_lora_model(model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) model = model.eval() xm.mark_step() xm.wait_device_ops() @@ -571,12 +581,15 @@ def dummy_run( num_seqs=num_tokens, ) - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + if self.lora_config is not None: # TODO: Remove this condition + torch._dynamo.config.capture_dynamic_output_shape_ops = True + else: + torch._dynamo.mark_dynamic(input_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None @@ -590,7 +603,8 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 while True: - self.dummy_run(self.kv_caches, num_tokens) + with self.maybe_profile_with_lora(self.lora_config, np.array([num_tokens], dtype=np.int32)): + self.dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() xm.wait_device_ops() diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 76b6297606c3..3cefbc4b181c 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -15,6 +15,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.scheduler import SchedulerOutput @@ -171,6 +172,9 @@ def profile(self, is_start: bool = True): else: xp.stop_trace() + def 
add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def load_model(self) -> None: self.model_runner.load_model() From b8c65bc126986280a2e50a7733ef6fb479eb9a88 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Feb 2025 16:49:27 +0000 Subject: [PATCH 049/186] Added test that verifies switching Signed-off-by: Akshat Tripathi --- test_switching.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 test_switching.py diff --git a/test_switching.py b/test_switching.py new file mode 100644 index 000000000000..ad84d47d3b8e --- /dev/null +++ b/test_switching.py @@ -0,0 +1,36 @@ +import vllm + +import torch_xla.debug.profiler as xp + +from vllm.lora.request import LoRARequest + +lora_paths = ["/mnt/ssd0/adapters/1", "/mnt/ssd0/adapters/2", "/mnt/ssd0/adapters/3", "/mnt/ssd0/adapters/4"] + +lora_requests = [ + LoRARequest("lora_adapter", i+1, lora_path) + for i, lora_path in enumerate(lora_paths) +] + +llm = vllm.LLM( + model="/mnt/ssd0/work_collection/downloaded_Qwen2.5-3b-Instruct_model/", + num_scheduler_steps=1, + swap_space=16, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + # enforce_eager=True, + max_loras=2, + max_lora_rank=8 +) + +for _ in range(2): + for i, req in enumerate(lora_requests): + print(i, llm.generate( + "What's 1+1?", + sampling_params=vllm.SamplingParams( + max_tokens=256, + temperature=0 + ), + lora_request=req + )) \ No newline at end of file From 942ef079dc9d27e249f6da3e612959f6f8aa5edb Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Feb 2025 15:56:55 +0000 Subject: [PATCH 050/186] Added bgmv kernel test code Signed-off-by: Akshat Tripathi --- bgmv.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 bgmv.py diff --git a/bgmv.py b/bgmv.py new file mode 100644 index 000000000000..a72448485769 --- /dev/null +++ b/bgmv.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 + +import jax +from jax import numpy as jnp +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def create_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + lora: jax.Array - shape (L, D) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + + Ignored: + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + loras: jax.Array - shape (N, 1, L, D) + """ + inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) + lora = jax.random.normal(jax.random.PRNGKey(1), (L, D)) + ref_output = inputs @ lora.T + + return inputs, lora, ref_output + + +def bgmv_kernel(inp_ref, lora_ref, out_ref, acc_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general(inp_ref[...], + lora_ref[...], + (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def bgmv(inputs: jax.Array, lora: jax.Array): + T, D = inputs.shape + L, _ = lora.shape + + # TODO: Tune + # Also figure out how to make bT % 128 instead of bL, + # or pick block sizes based off dims + bT = 8 + bL = 128 + bD = 128 + + return pl.pallas_call( + kernel=bgmv_kernel, + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + grid=(T // bT, L // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), lambda i, j, k: (i, k)), + pl.BlockSpec((bL, bD), lambda i, j, k: (j, k)), + ], + out_specs=pl.BlockSpec((bT, bL), lambda i, j, k: (i, j)), + scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + interpret=True)(inputs, lora) + + +if __name__ == "__main__": + T, D, L, N = 128, 3072, 128, 8 + inputs, lora, ref_output = create_tensors(T, D, L, N) + + print(lora.shape, inputs.shape, ref_output.shape) + + output1 = bgmv(inputs, lora) + + print(jnp.isnan(output1).sum(), "NaN values") + + # np.testing.assert_allclose(ref_output, output1) + # print("Success") From 56529b9a89688fb2ee5ad6fb02fc217d1127918a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Feb 2025 17:37:30 +0000 Subject: [PATCH 051/186] Added some dynamic lora selection Signed-off-by: Akshat Tripathi --- bgmv.py | 84 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/bgmv.py b/bgmv.py index a72448485769..e33c356a46c6 100644 --- a/bgmv.py +++ b/bgmv.py @@ -16,29 +16,59 @@ def create_tensors(T, D, L, N): Outputs: inputs: jax.Array - shape (T, D) - lora: jax.Array - shape (L, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T - - Ignored: - idxs: jax.Array - shape (T, ) - all values must be in [0, N) - loras: jax.Array - shape (N, 1, L, D) """ inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - lora = jax.random.normal(jax.random.PRNGKey(1), (L, D)) - ref_output = inputs @ lora.T + loras = jax.random.normal(jax.random.PRNGKey(1), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(2), + shape=(T, ), + minval=0, + maxval=N) + + ref_output = jnp.einsum("td,__ld->tl", inputs, loras[idxs]) + + return inputs, loras, idxs, ref_output + + +def create_debug_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + """ + inputs = jnp.ones((T, D)) + loras = jnp.ones((N, 1, L, D)) * jnp.arange(0, N)[:, None, None, None] + idxs = jax.random.randint(jax.random.PRNGKey(2), + shape=(T, ), + minval=0, + maxval=N) - return inputs, lora, ref_output + ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) + return inputs, loras, idxs, ref_output -def bgmv_kernel(inp_ref, lora_ref, out_ref, acc_ref): + +def bgmv_kernel(idx_ref, inp_ref, lora_ref, out_ref, acc_ref): + del idx_ref @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) acc_ref[...] 
+= jax.lax.dot_general(inp_ref[...], - lora_ref[...], + lora_ref[0, 0, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) @@ -48,9 +78,9 @@ def _(): @jax.jit -def bgmv(inputs: jax.Array, lora: jax.Array): +def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): T, D = inputs.shape - L, _ = lora.shape + N, _, L, _ = lora.shape # TODO: Tune # Also figure out how to make bT % 128 instead of bL, @@ -63,28 +93,38 @@ def bgmv(inputs: jax.Array, lora: jax.Array): kernel=bgmv_kernel, out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=0, + num_scalar_prefetch=1, grid=(T // bT, L // bL, D // bD), in_specs=[ - pl.BlockSpec((bT, bD), lambda i, j, k: (i, k)), - pl.BlockSpec((bL, bD), lambda i, j, k: (j, k)), + pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((1, 1, bL, bD), lambda i, j, k, block_idx: + (block_idx[i * bT], 0, j, k)), ], - out_specs=pl.BlockSpec((bT, bL), lambda i, j, k: (i, j)), + out_specs=pl.BlockSpec((bT, bL), lambda i, j, k, block_idx: + (i, j)), scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(inputs, lora) + interpret=True)(idxs, inputs, lora) if __name__ == "__main__": T, D, L, N = 128, 3072, 128, 8 - inputs, lora, ref_output = create_tensors(T, D, L, N) + inputs, lora, idxs, ref_output = create_debug_tensors(T, D, L, N) + print(idxs) + # breakpoint() print(lora.shape, inputs.shape, ref_output.shape) - output1 = bgmv(inputs, lora) + output = bgmv(inputs, lora, idxs) + + print(jnp.isnan(output).sum(), "NaN values") + + print("Err", jnp.max(jnp.abs(ref_output - output))) - print(jnp.isnan(output1).sum(), "NaN values") + output_idxs = (output / D)[:, 0] + print(output_idxs) + print(output_idxs == idxs) - # np.testing.assert_allclose(ref_output, output1) - # print("Success") + breakpoint() + # np.testing.assert_allclose(ref_output, output1, rtol=1e-2) From 735073ff85ea8da061c42c6fe812c3409ea1e509 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Feb 2025 18:48:08 +0000 Subject: [PATCH 052/186] Moved and modified bgmv ops from the cpu backend to the tpu backend, because xla doesn't allow partial updates Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 1fcedfb61a93..88458ed433f8 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -52,6 +52,7 @@ def expand_slice( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool, ) -> torch.Tensor: if self.no_lora: @@ -102,7 +103,6 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], - offset_start: int = 0, add_inputs=True, **kwargs) -> torch.Tensor: """ @@ -137,6 +137,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], + y_total_size=sum(output_slices), add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] From 1067b50a588cbbe170fe56f6a875050a6abfc54f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Feb 2025 15:55:23 +0000 Subject: [PATCH 053/186] Added bgmv kernel test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/__init__.py | 0 tests/lora/tpu/test_pallas_kernels.py | 
58 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/lora/tpu/__init__.py create mode 100644 tests/lora/tpu/test_pallas_kernels.py diff --git a/tests/lora/tpu/__init__.py b/tests/lora/tpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py new file mode 100644 index 000000000000..27be3be804e5 --- /dev/null +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from bgmv import bgmv + +N_TOKENS = [ + 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, + 131072 +] +HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] + +DTYPES = [jnp.float16, jnp.bfloat16] +NUM_LORA = [1, 2, 4, 8, 16, 32] +RANKS = [8, 16, 32, 64, 128] + + +def generate_test_data(T, D, L, N, seed, dtype=jnp.float32): + """ + Generates debug tensors for testing. + """ + inputs = jax.random.normal(jax.random.PRNGKey(seed), (T, D)) + loras = jax.random.normal(jax.random.PRNGKey(seed), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(seed), + shape=(T, ), + minval=0, + maxval=N) + + ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) + return inputs, loras, idxs, ref_output + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", [0]) +def test_bgmv(T, D, L, N, dtype, op_type, seed): + inputs, loras, idxs, ref_output = generate_test_data( + T, D, L, N, seed, dtype) + + # Run bgmv + match op_type: + case "expand": + output = bgmv(inputs, loras, idxs) # TODO: Specialise + case "shrink": + output = bgmv(inputs, loras, idxs) + + # Make sure we have no NaNs + assert jnp.isnan(output).sum() == 0 + + # Compare with reference output + np.testing.assert_allclose(output, ref_output, rtol=1e-3, atol=1e-3) From d897f878ec8494d9e95177d71f4334b4eb63fc8d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Feb 2025 16:59:20 +0000 Subject: [PATCH 054/186] Made bgmv kernel fully functional (WIP on supporting smaller ranks) (WIP on perf) Signed-off-by: Akshat Tripathi --- bgmv.py | 70 +++++++++++++++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/bgmv.py b/bgmv.py index e33c356a46c6..0959dae351a8 100644 --- a/bgmv.py +++ b/bgmv.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import functools + import jax from jax import numpy as jnp from jax.experimental import pallas as pl @@ -22,8 +24,8 @@ def create_tensors(T, D, L, N): ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T """ inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - loras = jax.random.normal(jax.random.PRNGKey(1), (N, 1, L, D)) - idxs = jax.random.randint(jax.random.PRNGKey(2), + loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(0), shape=(T, ), minval=0, maxval=N) @@ -50,7 +52,7 @@ def create_debug_tensors(T, D, L, N): """ inputs = jnp.ones((T, D)) loras = jnp.ones((N, 1, L, D)) * jnp.arange(0, N)[:, None, None, None] - idxs = jax.random.randint(jax.random.PRNGKey(2), + idxs = jax.random.randint(jax.random.PRNGKey(0), 
shape=(T, ), minval=0, maxval=N) @@ -60,17 +62,24 @@ def create_debug_tensors(T, D, L, N): return inputs, loras, idxs, ref_output -def bgmv_kernel(idx_ref, inp_ref, lora_ref, out_ref, acc_ref): - del idx_ref +def bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, + mask_ref): @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - acc_ref[...] += jax.lax.dot_general(inp_ref[...], - lora_ref[0, 0, ...], - (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): @@ -89,23 +98,30 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): bL = 128 bD = 128 - return pl.pallas_call( - kernel=bgmv_kernel, - out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // bT, L // bL, D // bD), - in_specs=[ - pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((1, 1, bL, bD), lambda i, j, k, block_idx: - (block_idx[i * bT], 0, j, k)), - ], - out_specs=pl.BlockSpec((bT, bL), lambda i, j, k, block_idx: - (i, j)), - scratch_shapes=[pltpu.VMEM((bT, bL), jnp.float32)]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(idxs, inputs, lora) + return pl.pallas_call(kernel=functools.partial(bgmv_kernel, bT, bL), + out_shape=jax.ShapeDtypeStruct((T, L), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // bT, L // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, 1, bL, bD), + lambda i, j, k, block_idx: + (0, 0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", + "arbitrary")), + interpret=True)(idxs, inputs, lora) if __name__ == "__main__": @@ -126,5 +142,5 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): print(output_idxs) print(output_idxs == idxs) - breakpoint() + # breakpoint() # np.testing.assert_allclose(ref_output, output1, rtol=1e-2) From d6eca294c839f914558299c7ec3c8969e9d423b6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Feb 2025 16:48:54 +0000 Subject: [PATCH 055/186] Updated bgmv_kernel to work with ranks that aren't exact multiples of 128 Signed-off-by: Akshat Tripathi --- bgmv.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bgmv.py b/bgmv.py index 0959dae351a8..ef2125263d4b 100644 --- a/bgmv.py +++ b/bgmv.py @@ -90,20 +90,24 @@ def _(): def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): T, D = inputs.shape N, _, L, _ = lora.shape + + # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU register + L1 = L + if L < 128 or L % 128 != 0: + L1 = (L // 128 + 1) * 128 + lora = jnp.pad(lora, ((0,0), (0,0), (0,L1-L), (0,0))) - # TODO: Tune - # Also figure out how to make bT % 128 instead of bL, - # or pick block sizes based off dims + # TODO: Tune these bT = 8 bL = 128 bD = 128 return pl.pallas_call(kernel=functools.partial(bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T, L), + out_shape=jax.ShapeDtypeStruct((T, L1), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - grid=(T // bT, L // bL, D // bD), + grid=(T // bT, L1 // bL, D // bD), in_specs=[ pl.BlockSpec((bT, bD), lambda i, j, k, block_idx: @@ -121,11 +125,11 @@ def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), - interpret=True)(idxs, inputs, lora) + interpret=True)(idxs, inputs, lora)[:, :L] if __name__ == "__main__": - T, D, L, N = 128, 3072, 128, 8 + T, D, L, N = 16, 3072, 8, 8 inputs, lora, idxs, ref_output = create_debug_tensors(T, D, L, N) print(idxs) # breakpoint() From d97aae549d2fddd3b00e140725803ee843129a82 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 12:00:37 +0000 Subject: [PATCH 056/186] Removed interpreted mode on kernel Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 12 ++--- vllm/lora/ops/xla_ops/pallas.py | 84 +++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 9 deletions(-) create mode 100644 vllm/lora/ops/xla_ops/pallas.py diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 483bef186185..cc541a8a8de5 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +from .pallas import bgmv def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -38,17 +39,10 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) - return scaling * outputs + return scaling * bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py new file mode 100644 index 000000000000..2889d21e774e --- /dev/null +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -0,0 +1,84 @@ +import functools +from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas +jax_import_guard() + +import jax +from jax.experimental import pallas as pl +import jax.numpy as jnp +from jax.experimental.pallas import tpu as pltpu + + +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, + mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] 
= jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): + T, D = inputs.shape + N, _, L, _ = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register + L1 = L + if L < 128 or L % 128 != 0: + L1 = (L // 128 + 1) * 128 + loras = jnp.pad(loras, ((0,0), (0,0), (0,L1-L), (0,0))) + + # TODO: Tune these + bT = 8 + bL = 128 + bD = 128 + + return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), + out_shape=jax.ShapeDtypeStruct((T, L1), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // bT, L1 // bL, D // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, 1, bL, bD), + lambda i, j, k, block_idx: + (0, 0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", + "arbitrary")))(idxs, inputs, loras)[:, :L] + +def bgmv_shape_function(inputs, loras, idxs): + T, _ = inputs.shape + _, _, L, _ = loras.shape + + return [((T, L), inputs.dtype)] + +def bgmv(inputs, loras, idxs): + kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + + return kernel(inputs, loras, idxs) \ No newline at end of file From 3ac0f63985a2205ff1bd2b04a613d392ee12067a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Feb 2025 13:14:28 +0000 Subject: [PATCH 057/186] Added pallas kernel benchmarking script Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 bmark_kernels.py diff --git a/bmark_kernels.py b/bmark_kernels.py new file mode 100644 index 000000000000..744c335dc449 --- /dev/null +++ b/bmark_kernels.py @@ -0,0 +1,47 @@ +import itertools +import pytest + +import jax +from jax import numpy as jnp +from vllm.lora.ops.xla_ops.pallas import _bgmv + +def create_tensors(T, D, L, N): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: jax.Array - shape (T, D) + loras: jax.Array - shape (N, 1, L, D) + idxs: jax.Array - shape (T, ) - all values must be in [0, N) + + ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T + """ + inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) + loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) + idxs = jax.random.randint(jax.random.PRNGKey(0), + shape=(T, ), + minval=0, + maxval=N) + + + return inputs, loras, idxs + +# SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +# HIDDEN_DIM = [1024, 2048, 3072, 4096] +# LORA_RANKS = [8, 16, 32, 64, 128, 256] +# N_LORAS = [1, 2, 4, 8, 16, 32] +SEQ_LENS = [16, 8192] +HIDDEN_DIM = [1024, 4096] +LORA_RANKS = [8, 256] +N_LORAS = [1, 32] + +@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +def test_bgmv_benchmark(benchmark, T, D, L, N): + inputs, loras, idxs = create_tensors(T, D, L, N) + + benchmark(_bgmv, 
inputs, loras, idxs) + From a620e58e62803531e4004cc8bc00e441aa270f21 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 13:34:13 +0000 Subject: [PATCH 058/186] Fixed mosaic kernel compilation issue Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 2889d21e774e..9c91f4b0df7d 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -33,7 +33,11 @@ def _(): @jax.jit -def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): +def _bgmv( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, 1, L, D) model dtype +) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, _, L, _ = loras.shape @@ -72,7 +76,7 @@ def _bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): dimension_semantics=("parallel", "parallel", "arbitrary")))(idxs, inputs, loras)[:, :L] -def bgmv_shape_function(inputs, loras, idxs): +def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, _, L, _ = loras.shape @@ -81,4 +85,4 @@ def bgmv_shape_function(inputs, loras, idxs): def bgmv(inputs, loras, idxs): kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - return kernel(inputs, loras, idxs) \ No newline at end of file + return kernel(idxs, inputs, loras) \ No newline at end of file From 00d6dfdb7d6583749a8e014c8a8445c75ac07d09 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 13:43:17 +0000 Subject: [PATCH 059/186] Added reference kernel benchmarking Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index 744c335dc449..23356ac152a1 100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -3,7 +3,7 @@ import jax from jax import numpy as jnp -from vllm.lora.ops.xla_ops.pallas import _bgmv +from vllm.lora.ops.xla_ops.pallas import bgmv def create_tensors(T, D, L, N): """ @@ -30,18 +30,18 @@ def create_tensors(T, D, L, N): return inputs, loras, idxs -# SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] -# HIDDEN_DIM = [1024, 2048, 3072, 4096] -# LORA_RANKS = [8, 16, 32, 64, 128, 256] -# N_LORAS = [1, 2, 4, 8, 16, 32] -SEQ_LENS = [16, 8192] -HIDDEN_DIM = [1024, 4096] -LORA_RANKS = [8, 256] -N_LORAS = [1, 32] +def ref_bgmv(inputs, loras, idxs): + return jnp.einsum("td,__ld->tl", inputs, loras[idxs]) + +SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] +HIDDEN_DIM = [1024, 2048, 3072, 4096] +LORA_RANKS = [8, 16, 32, 64, 128, 256] +N_LORAS = [1, 2, 4, 8, 16, 32] + @pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) def test_bgmv_benchmark(benchmark, T, D, L, N): inputs, loras, idxs = create_tensors(T, D, L, N) - benchmark(_bgmv, inputs, loras, idxs) - + benchmark.pedantic(ref_bgmv, args=(inputs, loras, idxs), rounds=10, warmup_rounds=5, iterations=10) From fb0601d33d76271bf141cd6bc307b855f713900c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 14:46:40 +0000 Subject: [PATCH 060/186] Registered the custom op Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 4 ++-- vllm/lora/ops/xla_ops/pallas.py | 22 ++++++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py 
b/vllm/lora/ops/xla_ops/lora_ops.py index cc541a8a8de5..8473180108fc 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from .pallas import bgmv +import vllm.lora.ops.xla_ops.pallas # Required to register the custom ops def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -42,7 +42,7 @@ def bgmv_shrink(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) - return scaling * bgmv(inputs, lora_b_weights, lora_indices_tensor) + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 9c91f4b0df7d..f7abbe5e187c 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -1,6 +1,7 @@ import functools -from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas -jax_import_guard() +import torch +from torch.library import impl +from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas, XLA_LIB import jax from jax.experimental import pallas as pl @@ -82,7 +83,20 @@ def bgmv_shape_function(idxs, inputs, loras): return [((T, L), inputs.dtype)] -def bgmv(inputs, loras, idxs): +XLA_LIB.define( + "bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", +) + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs, loras, idxs): + jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - return kernel(idxs, inputs, loras) \ No newline at end of file + return kernel(idxs, inputs, loras) + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs, loras, idxs): + T, _ = inputs.shape + _, _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) From 89b062e568c8feed11f324be8bcd0715b19134c7 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 14:53:06 +0000 Subject: [PATCH 061/186] Integrated bgmv kernel Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 8473180108fc..aced0aa34c69 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -8,15 +8,10 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(dim=1) inputs = inputs.to(dtype=output_tensor.dtype) - # outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) - batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + batch_size = outputs.size(0) + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -52,18 +47,10 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_size: int, add_inputs: bool = True): - selected_loras = lora_b_weights[lora_indices_tensor].to( - dtype=output_tensor.dtype) - inputs = inputs.to(dtype=output_tensor.dtype) - if len(selected_loras.shape) == 4: - selected_loras = 
selected_loras.squeeze(dim=1) - - batch_size, output_size, input_size = selected_loras.shape - - outputs = (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) + batch_size = outputs.size(0) + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), From ef2ef8c24f7c9435541b847bb2dd8c08fe89931b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Feb 2025 15:31:36 +0000 Subject: [PATCH 062/186] Fixed model compilation bugs Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 4 ++-- vllm/lora/punica_wrapper/punica_tpu.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index aced0aa34c69..3ef86ea0854f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -10,8 +10,8 @@ def bgmv_expand(inputs: torch.Tensor, add_inputs: bool = True): inputs = inputs.to(dtype=output_tensor.dtype) - batch_size = outputs.size(0) outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + batch_size = outputs.size(0) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -49,8 +49,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, inputs = inputs.to(dtype=output_tensor.dtype) - batch_size = outputs.size(0) outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + batch_size = outputs.size(0) outputs = torch.cat(( torch.zeros((batch_size, slice_offset), device=outputs.device), diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 88458ed433f8..6c244913e5dd 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,6 +22,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which isn't supported by the TPU. + # So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. 
+ self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) def shrink( self, From a79e19dc2a50214b896b27e4af5e0fab1fb6cced Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Feb 2025 16:20:00 +0000 Subject: [PATCH 063/186] Minor changes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 2 +- vllm/lora/punica_wrapper/punica_tpu.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index f7abbe5e187c..7c4716505743 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -91,7 +91,7 @@ def bgmv_shape_function(idxs, inputs, loras): def bgmv_xla(inputs, loras, idxs): jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - + return kernel(idxs, inputs, loras) @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 6c244913e5dd..2037b131488a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -69,11 +69,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ - Performs GEMM for multiple slices of lora_a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. + Performs GEMM for multiple slices of lora_a. Semantics: for i in range(len(lora_a_stacked)): From cc8cdf68787f8b7b3c83c00d2c02649b4ec85fdd Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 4 Mar 2025 21:21:45 +0000 Subject: [PATCH 064/186] Removed scratch files Signed-off-by: Akshat Tripathi --- bgmv.py | 150 ---------------------------------------------- bmark_kernels.py | 47 --------------- test_switching.py | 36 ----------- 3 files changed, 233 deletions(-) delete mode 100644 bgmv.py delete mode 100644 bmark_kernels.py delete mode 100644 test_switching.py diff --git a/bgmv.py b/bgmv.py deleted file mode 100644 index ef2125263d4b..000000000000 --- a/bgmv.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import functools - -import jax -from jax import numpy as jnp -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu - - -def create_tensors(T, D, L, N): - """ - Inputs: (All integers) - T: Total number of tokens - D: Input dim - L: LoRA Dim - N: N LoRAs - - Outputs: - inputs: jax.Array - shape (T, D) - loras: jax.Array - shape (N, 1, L, D) - idxs: jax.Array - shape (T, ) - all values must be in [0, N) - - ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T - """ - inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) - idxs = jax.random.randint(jax.random.PRNGKey(0), - shape=(T, ), - minval=0, - maxval=N) - - ref_output = jnp.einsum("td,__ld->tl", inputs, loras[idxs]) - - return inputs, loras, idxs, ref_output - - -def create_debug_tensors(T, D, L, N): - """ - Inputs: (All integers) - T: Total number of tokens - D: Input dim - L: LoRA Dim - N: N LoRAs - - Outputs: - inputs: jax.Array - shape (T, D) - loras: jax.Array - shape (N, 1, L, D) - idxs: jax.Array - shape (T, ) - all values must be in [0, N) - - ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T - 
""" - inputs = jnp.ones((T, D)) - loras = jnp.ones((N, 1, L, D)) * jnp.arange(0, N)[:, None, None, None] - idxs = jax.random.randint(jax.random.PRNGKey(0), - shape=(T, ), - minval=0, - maxval=N) - - ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) - - return inputs, loras, idxs, ref_output - - -def bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, - mask_ref): - - @pl.when(pl.program_id(2) == 0) - def _(): - acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - - t = pl.program_id(0) - - for i in range(bT): - idx = idx_ref[i + bT * t] - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) - - acc_ref[...] += jax.lax.dot_general( - inp_ref[...], - lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] - - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) - def _(): - out_ref[...] = acc_ref[...].astype(out_ref.dtype) - - -@jax.jit -def bgmv(inputs: jax.Array, lora: jax.Array, idxs: jax.Array): - T, D = inputs.shape - N, _, L, _ = lora.shape - - # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register - L1 = L - if L < 128 or L % 128 != 0: - L1 = (L // 128 + 1) * 128 - lora = jnp.pad(lora, ((0,0), (0,0), (0,L1-L), (0,0))) - - # TODO: Tune these - bT = 8 - bL = 128 - bD = 128 - - return pl.pallas_call(kernel=functools.partial(bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T, L1), - dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // bT, L1 // bL, D // bD), - in_specs=[ - pl.BlockSpec((bT, bD), - lambda i, j, k, block_idx: - (i, k)), - pl.BlockSpec((N, 1, bL, bD), - lambda i, j, k, block_idx: - (0, 0, j, k)), - ], - out_specs=pl.BlockSpec( - (bT, bL), lambda i, j, k, block_idx: (i, j)), - scratch_shapes=[ - pltpu.VMEM((bT, bL), jnp.float32), - pltpu.VMEM((bT, bL), jnp.float32) - ]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", - "arbitrary")), - interpret=True)(idxs, inputs, lora)[:, :L] - - -if __name__ == "__main__": - T, D, L, N = 16, 3072, 8, 8 - inputs, lora, idxs, ref_output = create_debug_tensors(T, D, L, N) - print(idxs) - # breakpoint() - - print(lora.shape, inputs.shape, ref_output.shape) - - output = bgmv(inputs, lora, idxs) - - print(jnp.isnan(output).sum(), "NaN values") - - print("Err", jnp.max(jnp.abs(ref_output - output))) - - output_idxs = (output / D)[:, 0] - print(output_idxs) - print(output_idxs == idxs) - - # breakpoint() - # np.testing.assert_allclose(ref_output, output1, rtol=1e-2) diff --git a/bmark_kernels.py b/bmark_kernels.py deleted file mode 100644 index 23356ac152a1..000000000000 --- a/bmark_kernels.py +++ /dev/null @@ -1,47 +0,0 @@ -import itertools -import pytest - -import jax -from jax import numpy as jnp -from vllm.lora.ops.xla_ops.pallas import bgmv - -def create_tensors(T, D, L, N): - """ - Inputs: (All integers) - T: Total number of tokens - D: Input dim - L: LoRA Dim - N: N LoRAs - - Outputs: - inputs: jax.Array - shape (T, D) - loras: jax.Array - shape (N, 1, L, D) - idxs: jax.Array - shape (T, ) - all values must be in [0, N) - - ref_output: jax.Array - shape (T, L) - inputs @ loras[idxs].T - """ - inputs = jax.random.normal(jax.random.PRNGKey(0), (T, D)) - loras = jax.random.normal(jax.random.PRNGKey(0), (N, 1, L, D)) - idxs = jax.random.randint(jax.random.PRNGKey(0), - shape=(T, ), - minval=0, - maxval=N) - - - return inputs, loras, idxs - -def ref_bgmv(inputs, 
loras, idxs): - return jnp.einsum("td,__ld->tl", inputs, loras[idxs]) - -SEQ_LENS = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] -HIDDEN_DIM = [1024, 2048, 3072, 4096] -LORA_RANKS = [8, 16, 32, 64, 128, 256] -N_LORAS = [1, 2, 4, 8, 16, 32] - - -@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) -@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) -def test_bgmv_benchmark(benchmark, T, D, L, N): - inputs, loras, idxs = create_tensors(T, D, L, N) - - benchmark.pedantic(ref_bgmv, args=(inputs, loras, idxs), rounds=10, warmup_rounds=5, iterations=10) diff --git a/test_switching.py b/test_switching.py deleted file mode 100644 index ad84d47d3b8e..000000000000 --- a/test_switching.py +++ /dev/null @@ -1,36 +0,0 @@ -import vllm - -import torch_xla.debug.profiler as xp - -from vllm.lora.request import LoRARequest - -lora_paths = ["/mnt/ssd0/adapters/1", "/mnt/ssd0/adapters/2", "/mnt/ssd0/adapters/3", "/mnt/ssd0/adapters/4"] - -lora_requests = [ - LoRARequest("lora_adapter", i+1, lora_path) - for i, lora_path in enumerate(lora_paths) -] - -llm = vllm.LLM( - model="/mnt/ssd0/work_collection/downloaded_Qwen2.5-3b-Instruct_model/", - num_scheduler_steps=1, - swap_space=16, - max_model_len=256, - max_seq_len_to_capture=256, - max_num_seqs=8, - enable_lora=True, - # enforce_eager=True, - max_loras=2, - max_lora_rank=8 -) - -for _ in range(2): - for i, req in enumerate(lora_requests): - print(i, llm.generate( - "What's 1+1?", - sampling_params=vllm.SamplingParams( - max_tokens=256, - temperature=0 - ), - lora_request=req - )) \ No newline at end of file From ad8c56586114681f30e1f0652e58566d59cb1050 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 5 Mar 2025 15:23:28 +0000 Subject: [PATCH 065/186] Minor pallas kernel fixes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 8 -------- vllm/lora/ops/xla_ops/pallas.py | 31 ++++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 3ef86ea0854f..7916dea8772f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -8,8 +8,6 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): - inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) batch_size = outputs.size(0) @@ -34,9 +32,6 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - - inputs = inputs.to(dtype=output_tensor.dtype) - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, @@ -46,9 +41,6 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - - inputs = inputs.to(dtype=output_tensor.dtype) - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) batch_size = outputs.size(0) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 7c4716505743..e9914d408fda 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -42,16 +42,26 @@ def _bgmv( T, D = inputs.shape N, _, L, _ = loras.shape - # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU register - L1 = L - if L < 128 or L % 128 != 0: - L1 = (L // 128 + 1) * 128 - loras = jnp.pad(loras, ((0,0), (0,0), (0,L1-L), (0,0))) - # TODO: Tune these bT = 8 bL = 128 bD = 128 + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register + L1 = L + if L < bL or L % bL != 0: + L1 = (L // bL + 1) * bL + + D1 = D + if D < bD or D % bD != 0: + D1 = (D // bD + 1) * bD + + T1 = T + if T < bT or T % bT != 0: + T1 = (T // bT + 1) * bT + + loras = jnp.pad(loras, ((0,0), (0,0), (0,L1-L), (0,D1-D))) + inputs = jnp.pad(inputs, ((0,T1-T), (0, D1-D))) return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), out_shape=jax.ShapeDtypeStruct((T, L1), @@ -88,14 +98,17 @@ def bgmv_shape_function(idxs, inputs, loras): ) @impl(XLA_LIB, "bgmv", "XLA") -def bgmv_xla(inputs, loras, idxs): +def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - + return kernel(idxs, inputs, loras) + @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs, loras, idxs): +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape _, _, L, _ = loras.shape From 8d83065cc97e871a9354c1a81fe6393ea1e482d7 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Mon, 3 Mar 2025 17:30:44 +0000 Subject: [PATCH 066/186] integrate ragged paged attn v2 Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 26 +++++------ vllm/v1/worker/tpu_model_runner.py | 65 +++++++++++----------------- 2 files changed, 36 insertions(+), 55 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 543e8487e28b..449f5fc3b1fd 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -41,7 +41,7 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> tuple[int, ...]: - return (num_kv_heads, num_blocks, block_size, head_size) + return (num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( @@ -131,8 +131,8 @@ def forward( query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = ([num_kv_heads, num_blocks, block_size, head_size], - [num_kv_heads, num_blocks, block_size, head_size]) + kv_cache = ([num_blocks, block_size, num_kv_heads, head_size], + [num_kv_heads, num_blocks, block_size, num_kv_heads, head_size]) attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -154,10 +154,6 @@ def forward( slot_mapping = attn_metadata.slot_mapping write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) - query = query * self.scale - # use_kernel switches between using kernel or reference implementation - # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890). 
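        # Note on the change below: the reference attention path (use_kernel=False)
        # is dropped in favour of the real ragged paged attention kernel; the query
        # is no longer pre-scaled here, since the scale is now handed to the kernel
        # as sm_scale=self.scale together with an explicit vmem_limit_bytes budget.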
- use_kernel = False output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -168,8 +164,9 @@ def forward( attn_metadata.num_seqs, num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, num_queries_per_block=NUM_QUERIES_PER_BLOCK, - use_kernel=use_kernel, - ) + vmem_limit_bytes=32 * 1024 * 1024, + use_kernel=True, + sm_scale=self.scale) return output.reshape(num_tokens, hidden_size) @@ -186,16 +183,15 @@ def write_to_kv_cache( Args: key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - k_cache = [num_kv_heads, num_blocks, block_size, head_size] - v_cache = [num_kv_heads, num_blocks, block_size, head_size] + k_cache = [num_blocks, block_size, num_kv_heads, head_size] + v_cache = [num_blocks, block_size, num_kv_heads, head_size] """ torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) - key = key.flatten(0, 1) - value = value.flatten(0, 1) - key_cache = key_cache.flatten(0, 2) - value_cache = value_cache.flatten(0, 2) + key_cache = key_cache.flatten(0, 1) + value_cache = value_cache.flatten(0, 1) + slot_mapping = slot_mapping.flatten() key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f9a3217fbef3..24d6d058c026 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -416,8 +416,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): num_scheduled_tokens_per_req) # Do the padding and copy the tensors to the TPU. - padded_total_num_scheduled_tokens = _get_padded_number( - total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK) + padded_total_num_scheduled_tokens = _get_padded_token_len( + total_num_scheduled_tokens) self.input_ids = self.input_ids_cpu[: padded_total_num_scheduled_tokens].to( self.device) @@ -428,23 +428,22 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): slot_mapping = self.slot_mapping_cpu[: padded_total_num_scheduled_tokens].to( self.device) - padded_block_table = self.block_table_cpu[: - padded_total_num_scheduled_tokens] + padded_block_table = self.block_table_cpu[:self.max_num_reqs] padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = ( self.input_batch.block_table.get_cpu_tensor()[:num_reqs]) padded_block_table = padded_block_table.to(self.device) - query_start_loc = self.query_start_loc_cpu[: - padded_total_num_scheduled_tokens - + 1].to(self.device) - seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to( + query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) + seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=padded_block_table, context_lens=seq_lens, query_start_loc=query_start_loc, - num_seqs=num_reqs, + num_seqs=torch.tensor([num_reqs], + dtype=torch.int32, + device=self.device), ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. 
While we should not sample any token from this @@ -693,29 +692,34 @@ def _dummy_run( dtype=torch.int32, device=self.device) inputs_embeds = None + actual_num_reqs = min(num_tokens, self.max_num_reqs) position_ids = torch.zeros(num_tokens, dtype=torch.int32, device=self.device) slot_mapping = torch.zeros(num_tokens, dtype=torch.int64, device=self.device) - block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]), - dtype=torch.int32, - device=self.device) - query_lens = [1] * num_tokens + block_tables = torch.zeros( + (self.max_num_reqs, self.block_table_cpu.shape[1]), + dtype=torch.int32, + device=self.device) + query_lens = [1] * self.max_num_reqs query_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32).to(self.device) - context_lens = torch.ones((num_tokens, ), + context_lens = torch.ones((self.max_num_reqs, ), dtype=torch.int32, device=self.device) + num_seqs = torch.tensor([actual_num_reqs], + dtype=torch.int32, + device=self.device) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, context_lens=context_lens, query_start_loc=query_start_loc, - num_seqs=num_tokens, + num_seqs=num_seqs, ) if self.is_multimodal_model: @@ -724,9 +728,6 @@ def _dummy_run( torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) with set_forward_context(attn_metadata, self.vllm_config, 0): assert self.model is not None @@ -817,28 +818,6 @@ def forward( inputs_embeds: The input embeddings of shape [num_tokens, hidden_size]. It is used for multimodal models. """ - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - attn_metadata = get_forward_context().attn_metadata - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. 
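        # Why the block below can go: with the cache laid out as
        # (num_blocks, block_size, num_kv_heads, head_size), write_to_kv_cache simply
        # flattens the first two dimensions and does index_copy_ along dim 0, so the
        # per-head offset arithmetic that rewrote slot_mapping here is no longer needed.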
- # kv_caches: list[tuple[torch.Tensor, torch.Tensor]] - num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indicies = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indicies *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indicies.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping assert self.model is not None hidden_states = self.model( @@ -866,3 +845,9 @@ def get_input_embeddings(self, *args, **kwargs): def _get_padded_number(n: int, multiple: int) -> int: return ((n + multiple - 1) // multiple) * multiple + + +def _get_padded_token_len(x: int) -> int: + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() From dea7d028696973e8918838115ece89b85c049eee Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 5 Mar 2025 21:44:33 +0000 Subject: [PATCH 067/186] fix precompile Signed-off-by: Chengji Yao --- vllm/v1/attention/backends/pallas.py | 2 +- vllm/v1/worker/tpu_model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 449f5fc3b1fd..a1b906c3ccd2 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -132,7 +132,7 @@ def forward( key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = ([num_blocks, block_size, num_kv_heads, head_size], - [num_kv_heads, num_blocks, block_size, num_kv_heads, head_size]) + [num_blocks, block_size, num_kv_heads, head_size]) attn_metadata: Metadata for attention. 
Returns: shape = [num_tokens, num_heads * head_size] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 24d6d058c026..7e6912e48765 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context, set_forward_context +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model From 62493074d869911f1a7729660976d273269986af Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 11:51:37 +0000 Subject: [PATCH 068/186] Fixed padding issue with v1 Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0728bfedc7f6..a906ec5f0dcc 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -438,7 +438,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) if self.lora_config is not None: - self.set_active_loras(self.input_batch, num_scheduled_tokens_per_req) + self.set_active_loras(self.input_batch, np.array(padded_total_num_scheduled_tokens)) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, From af0a6a97ff49a3666750c6e939bdcdf02ae193b8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 13:25:23 +0000 Subject: [PATCH 069/186] Added temporary patch over pallas kernel routing bug Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 76 +++++++++++++++++--------- vllm/lora/punica_wrapper/punica_tpu.py | 6 +- 2 files changed, 53 insertions(+), 29 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index e9914d408fda..8e1ef86de745 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -1,6 +1,7 @@ import functools import torch from torch.library import impl +import torch_xla from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas, XLA_LIB import jax @@ -25,7 +26,7 @@ def _(): acc_ref[...] += jax.lax.dot_general( inp_ref[...], - lora_ref[idx, 0, ...], (((1, ), (1, )), ((), ())), + lora_ref[idx, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] 
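        # (How the loop above routes each token to its adapter: every iteration forms
        #  the full (bT, bL) product of the input block against the LoRA tile selected
        #  by idx_ref[i + bT * t], and the one-hot row mask keeps only row i of that
        #  product, so each token accumulates a contribution from its own LoRA only.)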
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1) @@ -40,7 +41,7 @@ def _bgmv( loras: jax.Array # (N, 1, L, D) model dtype ) -> jax.Array: # (T, L) model dtype T, D = inputs.shape - N, _, L, _ = loras.shape + N, L, _ = loras.shape # TODO: Tune these bT = 8 @@ -60,36 +61,37 @@ def _bgmv( if T < bT or T % bT != 0: T1 = (T // bT + 1) * bT - loras = jnp.pad(loras, ((0,0), (0,0), (0,L1-L), (0,D1-D))) + loras = jnp.pad(loras, ((0,0), (0,L1-L), (0,D1-D))) inputs = jnp.pad(inputs, ((0,T1-T), (0, D1-D))) return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T, L1), - dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // bT, L1 // bL, D // bD), - in_specs=[ - pl.BlockSpec((bT, bD), - lambda i, j, k, block_idx: - (i, k)), - pl.BlockSpec((N, 1, bL, bD), - lambda i, j, k, block_idx: - (0, 0, j, k)), - ], - out_specs=pl.BlockSpec( - (bT, bL), lambda i, j, k, block_idx: (i, j)), - scratch_shapes=[ - pltpu.VMEM((bT, bL), jnp.float32), - pltpu.VMEM((bT, bL), jnp.float32) - ]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", - "arbitrary")))(idxs, inputs, loras)[:, :L] + out_shape=jax.ShapeDtypeStruct((T1, L1), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T1 // bT, L1 // bL, D1 // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, bL, bD), + lambda i, j, k, block_idx: + (0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv" + )(idxs, inputs, loras)[:T, :L] def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape - _, _, L, _ = loras.shape + _, L, _ = loras.shape return [((T, L), inputs.dtype)] @@ -97,10 +99,32 @@ def bgmv_shape_function(idxs, inputs, loras): "bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) +def ref_bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): + selected_loras = loras[idxs] + n_tokens, output_size, input_size = selected_loras.shape + outputs = ( + selected_loras @ inputs.reshape((n_tokens, input_size, 1)) + ).reshape((n_tokens, output_size)) + + return outputs + @impl(XLA_LIB, "bgmv", "XLA") def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, L, D = loras.shape + + # FIXME: Routing the output from 1 Pallas kernel directly to another results in NaN outputs + # so here we fallback on a reference implementation until the bug is fixed + use_reference_on_shrink = True + if use_reference_on_shrink and L < D: + return ref_bgmv(inputs, loras, idxs) + elif not use_reference_on_shrink and D < L: + return ref_bgmv(inputs, loras, idxs) + jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 2037b131488a..a3d1b482624e 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -191,7 +191,7 @@ def add_lora_linear(self, Args: y (torch.Tensor): Output tensor. Will be changed in-place. 
- x (torch.Tensor): Input tensor (B, S, E) + x (torch.Tensor): Input tensor (T, E) lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. @@ -210,9 +210,9 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default, consistent with the # triton op - batch_size, seq_len, _ = x.shape + T = x.size(0) buffer = torch.zeros( - (len(output_slices), batch_size * seq_len, r), + (len(output_slices), T, r), dtype=torch.float32, device=x.device, ) From 264d36a07552e729bb372856c6de9ae59e31d39d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 13:31:27 +0000 Subject: [PATCH 070/186] Updated kernel test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 50 ++++++++++++++++++--------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 27be3be804e5..89423463dee0 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -1,10 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -import jax -import jax.numpy as jnp -import numpy as np import pytest -from bgmv import bgmv +import torch +import vllm.lora.ops.xla_ops.pallas # Required to register the custom ops N_TOKENS = [ 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, @@ -19,18 +17,36 @@ def generate_test_data(T, D, L, N, seed, dtype=jnp.float32): """ - Generates debug tensors for testing. + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: torch.Tensor - shape (T, D) + loras: torch.Tensor - shape (N, 1, L, D) + idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) + + ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T """ - inputs = jax.random.normal(jax.random.PRNGKey(seed), (T, D)) - loras = jax.random.normal(jax.random.PRNGKey(seed), (N, 1, L, D)) - idxs = jax.random.randint(jax.random.PRNGKey(seed), - shape=(T, ), - minval=0, - maxval=N) - - ref_output = jnp.einsum("td,t_ld->tl", inputs, loras[idxs]) + + inputs = torch.randn((T, D), device="xla") + loras = torch.randn((N, 1, L, D), device="xla") + idxs = torch.randint(0, N, (T,), dtype=torch.int32, device="xla") + + ref_output = ref_bgmv(inputs, loras, idxs) return inputs, loras, idxs, ref_output +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): + selected_loras = loras[idxs] + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(axis=1) + + batch_size, output_size, input_size = selected_loras.shape + outputs = ( + selected_loras @ inputs.reshape((batch_size, input_size, 1)) + ).reshape((batch_size, output_size)) # Parameterize tests with various shapes and dtypes @pytest.mark.parametrize("T", N_TOKENS) @@ -47,12 +63,12 @@ def test_bgmv(T, D, L, N, dtype, op_type, seed): # Run bgmv match op_type: case "expand": - output = bgmv(inputs, loras, idxs) # TODO: Specialise + output = torch.ops.xla.bgmv(inputs, loras, idxs) # TODO: Specialise case "shrink": - output = bgmv(inputs, loras, idxs) + output = torch.ops.xla.bgmv(inputs, loras, idxs) # Make sure we have no NaNs - assert jnp.isnan(output).sum() == 0 + assert not torch.any(torch.isnan(output)) # Compare with reference output - np.testing.assert_allclose(output, ref_output, rtol=1e-3, atol=1e-3) + assert torch.allclose(output, ref_output, 
rtol=1e-3, atol=1e-3) From b725c6a93d694322ec021fec1538bde1b22e4e11 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 13:38:24 +0000 Subject: [PATCH 071/186] Lint Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 25 ++--- tests/tpu/test_lora.py | 28 +++--- vllm/lora/ops/xla_ops/lora_ops.py | 10 +- vllm/lora/ops/xla_ops/pallas.py | 130 +++++++++++++------------ vllm/lora/punica_wrapper/punica_tpu.py | 7 +- vllm/lora/punica_wrapper/utils.py | 6 +- vllm/v1/worker/tpu_model_runner.py | 14 +-- 7 files changed, 118 insertions(+), 102 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 89423463dee0..69df033b0705 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest - import torch -import vllm.lora.ops.xla_ops.pallas # Required to register the custom ops + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import N_TOKENS = [ 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, @@ -10,12 +11,12 @@ ] HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] -DTYPES = [jnp.float16, jnp.bfloat16] +DTYPES = [torch.float16, torch.bfloat16] NUM_LORA = [1, 2, 4, 8, 16, 32] RANKS = [8, 16, 32, 64, 128] -def generate_test_data(T, D, L, N, seed, dtype=jnp.float32): +def generate_test_data(T, D, L, N, seed, dtype=torch.float32): """ Inputs: (All integers) T: Total number of tokens @@ -30,23 +31,24 @@ def generate_test_data(T, D, L, N, seed, dtype=jnp.float32): ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T """ - + inputs = torch.randn((T, D), device="xla") loras = torch.randn((N, 1, L, D), device="xla") - idxs = torch.randint(0, N, (T,), dtype=torch.int32, device="xla") + idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") ref_output = ref_bgmv(inputs, loras, idxs) return inputs, loras, idxs, ref_output + def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): selected_loras = loras[idxs] if len(selected_loras.shape) == 4: selected_loras = selected_loras.squeeze(axis=1) - + batch_size, output_size, input_size = selected_loras.shape - outputs = ( - selected_loras @ inputs.reshape((batch_size, input_size, 1)) - ).reshape((batch_size, output_size)) + outputs = (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + # Parameterize tests with various shapes and dtypes @pytest.mark.parametrize("T", N_TOKENS) @@ -63,7 +65,8 @@ def test_bgmv(T, D, L, N, dtype, op_type, seed): # Run bgmv match op_type: case "expand": - output = torch.ops.xla.bgmv(inputs, loras, idxs) # TODO: Specialise + output = torch.ops.xla.bgmv(inputs, loras, + idxs) # TODO: Specialise case "shrink": output = torch.ops.xla.bgmv(inputs, loras, idxs) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index d3d2c7eb2e1d..105cc672a9ba 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -1,7 +1,8 @@ +# SPDX-License-Identifier: Apache-2.0 import vllm - from vllm.lora.request import LoRARequest + def test_lora_hotswapping(): lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ @@ -9,20 +10,21 @@ def test_lora_hotswapping(): for i in range(1, 5) ] - llm = vllm.LLM( - model="Qwen/Qwen2.5-3B-Instruct", - num_scheduler_steps=1, - max_model_len=256, - max_seq_len_to_capture=256, - 
max_num_seqs=8, - enable_lora=True, - max_loras=2, - max_lora_rank=8 - ) + llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=2, + max_lora_rank=8) prompt = "What is 1+1? \n" for _ in range(10): for i, req in enumerate(lora_requests): - output = llm.generate(prompt, sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), lora_request=req)[0].outputs[0].text - assert int(output.strip()[0]) == i + 1 \ No newline at end of file + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + assert int(output.strip()[0]) == i + 1 diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 7916dea8772f..5f051575d3fc 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import vllm.lora.ops.xla_ops.pallas # Required to register the custom ops + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -26,13 +29,16 @@ def bgmv_expand(inputs: torch.Tensor, else: return outputs[:limit, :] + def bgmv_shrink(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, + lora_indices_tensor) + def bgmv_expand_slice(inputs: torch.Tensor, lora_b_weights: torch.Tensor, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 8e1ef86de745..92288a0edd3b 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -1,17 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 import functools -import torch -from torch.library import impl -import torch_xla -from torch_xla.experimental.custom_kernel import jax_import_guard, make_kernel_from_pallas, XLA_LIB import jax -from jax.experimental import pallas as pl import jax.numpy as jnp +import torch +from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu +from torch.library import impl +from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, + make_kernel_from_pallas) -def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, - mask_ref): +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, + acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) def _(): @@ -36,104 +37,105 @@ def _(): @jax.jit def _bgmv( - idxs: jax.Array, # (T, ) int32 - inputs: jax.Array, # (T, D) model dtype - loras: jax.Array # (N, 1, L, D) model dtype -) -> jax.Array: # (T, L) model dtype + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, 1, L, D) model dtype +) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape - + # TODO: Tune these bT = 8 bL = 128 bD = 128 - + # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU register L1 = L - if L < bL or L % bL != 0: + if bL > L or L % bL != 0: L1 = (L // bL + 1) * bL - + D1 = D - if D < bD or D % bD != 0: + if bD > D or D % bD != 0: D1 = (D // bD + 1) * bD - + T1 = T - if T < bT or T % bT != 0: + if bT > T or T % bT != 0: T1 = (T // bT + 1) * bT - - loras = jnp.pad(loras, ((0,0), (0,L1-L), (0,D1-D))) - inputs = jnp.pad(inputs, ((0,T1-T), (0, D1-D))) + + loras = jnp.pad(loras, ((0, 0), (0, L1 - L), (0, D1 - D))) + inputs = jnp.pad(inputs, ((0, T1 - T), (0, D1 - D))) return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T1, L1), - dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T1 // bT, L1 // bL, D1 // bD), - in_specs=[ - pl.BlockSpec((bT, bD), - lambda i, j, k, block_idx: - (i, k)), - pl.BlockSpec((N, bL, bD), - lambda i, j, k, block_idx: - (0, j, k)), - ], - out_specs=pl.BlockSpec( - (bT, bL), lambda i, j, k, block_idx: (i, j)), - scratch_shapes=[ - pltpu.VMEM((bT, bL), jnp.float32), - pltpu.VMEM((bT, bL), jnp.float32) - ]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary")), - name="bgmv" - )(idxs, inputs, loras)[:T, :L] + out_shape=jax.ShapeDtypeStruct((T1, L1), + dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T1 // bT, L1 // bL, D1 // bD), + in_specs=[ + pl.BlockSpec((bT, bD), + lambda i, j, k, block_idx: + (i, k)), + pl.BlockSpec((N, bL, bD), + lambda i, j, k, block_idx: + (0, j, k)), + ], + out_specs=pl.BlockSpec( + (bT, bL), lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((bT, bL), jnp.float32), + pltpu.VMEM((bT, bL), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", + "arbitrary")), + name="bgmv")(idxs, inputs, loras)[:T, :L] + def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, L, _ = loras.shape - + return [((T, L), inputs.dtype)] -XLA_LIB.define( - "bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", -) + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) + def ref_bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): - selected_loras = loras[idxs] + selected_loras = loras[idxs] n_tokens, output_size, input_size = selected_loras.shape - outputs = ( - selected_loras @ inputs.reshape((n_tokens, input_size, 1)) - ).reshape((n_tokens, output_size)) - + outputs = (selected_loras @ inputs.reshape( + (n_tokens, input_size, 1))).reshape((n_tokens, output_size)) + return outputs + @impl(XLA_LIB, "bgmv", "XLA") def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) - + if len(loras.shape) == 4: loras = loras.squeeze(axis=1) - + _, L, D = loras.shape - + # FIXME: Routing the output from 1 Pallas kernel directly to another results in NaN outputs - # so here we fallback on a reference implementation until the bug is fixed + # so here we fallback on a reference implementation until the bug is fixed. The kernel can + # be used for either shrink or expand, but not both at the same time. 
use_reference_on_shrink = True - if use_reference_on_shrink and L < D: - return ref_bgmv(inputs, loras, idxs) - elif not use_reference_on_shrink and D < L: + if use_reference_on_shrink and L < D or not use_reference_on_shrink and D < L: return ref_bgmv(inputs, loras, idxs) - + jax_import_guard() kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - + return kernel(idxs, inputs, loras) @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): T, _ = inputs.shape _, _, L, _ = loras.shape - + return torch.empty((T, L), device=inputs.device) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index a3d1b482624e..0d009d9ed8c9 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -22,11 +22,12 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which isn't supported by the TPU. - # So convert those tensors to int32. + # So convert those tensors to int32. # Not all of them are used by the TPU so only convert the useful ones. - self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) + self._token_lora_indices = self._token_lora_indices.to( + dtype=torch.int32) def shrink( self, diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 00c3689ef462..bb9c606cda98 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,13 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices = torch.where(embeddings_indices == -1, embeddings_indices, max_loras - 1) + embeddings_indices = torch.where(embeddings_indices == -1, + embeddings_indices, max_loras - 1) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded = torch.where(sampler_indices_padded == -1, sampler_indices_padded, max_loras - 1) + sampler_indices_padded = torch.where(sampler_indices_padded == -1, + sampler_indices_padded, max_loras - 1) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a906ec5f0dcc..0d360646ad9d 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -436,9 +436,10 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to( self.device) seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) - + if self.lora_config is not None: - self.set_active_loras(self.input_batch, np.array(padded_total_num_scheduled_tokens)) + self.set_active_loras(self.input_batch, + np.array(padded_total_num_scheduled_tokens)) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, @@ -673,11 +674,9 @@ def load_model(self) -> None: return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) if self.lora_config: - model = self.load_lora_model(model, - self.model_config, + model = self.load_lora_model(model, 
self.model_config, self.scheduler_config, - self.lora_config, - self.device) + self.lora_config, self.device) model = model.eval() xm.mark_step() xm.wait_device_ops() @@ -756,7 +755,8 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 while True: - with self.maybe_profile_with_lora(self.lora_config, np.array([num_tokens], dtype=np.int32)): + with self.maybe_profile_with_lora( + self.lora_config, np.array([num_tokens], dtype=np.int32)): self._dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) xm.mark_step() From 038465c5e3901e69340d8d1df435f2008bca2907 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 14:24:26 +0000 Subject: [PATCH 072/186] Removed duplicate method Signed-off-by: Akshat Tripathi --- vllm/platforms/tpu.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0c9d247d4a5d..339d35651e14 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -57,11 +57,6 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 - @classmethod - def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on TPU.") - return False - @classmethod def get_punica_wrapper(cls) -> str: return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" From 20043696e5e95fe2b50fdec6b0e9dd79161e2bfc Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 14:32:45 +0000 Subject: [PATCH 073/186] Lint Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 10 ++++------ vllm/lora/punica_wrapper/punica_tpu.py | 1 + 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 69df033b0705..bcb7bb877ffd 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -63,12 +63,10 @@ def test_bgmv(T, D, L, N, dtype, op_type, seed): T, D, L, N, seed, dtype) # Run bgmv - match op_type: - case "expand": - output = torch.ops.xla.bgmv(inputs, loras, - idxs) # TODO: Specialise - case "shrink": - output = torch.ops.xla.bgmv(inputs, loras, idxs) + if op_type == "expand": + output = torch.ops.xla.bgmv(inputs, loras, idxs) # TODO: Specialise + else: + output = torch.ops.xla.bgmv(inputs, loras, idxs) # Make sure we have no NaNs assert not torch.any(torch.isnan(output)) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 0d009d9ed8c9..f118e011d599 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -105,6 +105,7 @@ def add_expand(self, lora_b_stacked: Tuple[torch.Tensor, ...], lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], + offset_start: int = 0, add_inputs=True, **kwargs) -> torch.Tensor: """ From 71a1cdd28e7be3a74e3cb704f91f10ce248955d8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 14:36:19 +0000 Subject: [PATCH 074/186] More linting Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 2 +- tests/tpu/test_lora.py | 3 ++- vllm/lora/ops/xla_ops/pallas.py | 14 ++++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index bcb7bb877ffd..6ebaf6de493b 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ 
b/tests/lora/tpu/test_pallas_kernels.py @@ -46,7 +46,7 @@ def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): selected_loras = selected_loras.squeeze(axis=1) batch_size, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( + return (selected_loras @ inputs.reshape( (batch_size, input_size, 1))).reshape((batch_size, output_size)) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index 105cc672a9ba..2fafd9b1fc2d 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -4,7 +4,8 @@ def test_lora_hotswapping(): - lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) for i in range(1, 5) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 92288a0edd3b..c95bfc7c8e09 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -49,7 +49,8 @@ def _bgmv( bL = 128 bD = 128 - # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU register + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register L1 = L if bL > L or L % bL != 0: L1 = (L // bL + 1) * bL @@ -119,11 +120,12 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): _, L, D = loras.shape - # FIXME: Routing the output from 1 Pallas kernel directly to another results in NaN outputs - # so here we fallback on a reference implementation until the bug is fixed. The kernel can - # be used for either shrink or expand, but not both at the same time. - use_reference_on_shrink = True - if use_reference_on_shrink and L < D or not use_reference_on_shrink and D < L: + # FIXME: Routing the output from 1 Pallas kernel directly to another results + # in NaN outputs so here we fallback on a reference implementation until the + # bug is fixed. The kernel can be used for either shrink or expand, but not + # both at the same time. 
+ use_reference = True + if use_reference and L < D or not use_reference and D < L: return ref_bgmv(inputs, loras, idxs) jax_import_guard() From 3dba9e0bde2e57cf6dcd268ed6d26b7450bc15b2 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 14:42:37 +0000 Subject: [PATCH 075/186] Linting Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 5 +++-- vllm/lora/ops/xla_ops/pallas.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 6ebaf6de493b..dbe8fcbfcb9b 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -31,9 +31,10 @@ def generate_test_data(T, D, L, N, seed, dtype=torch.float32): ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T """ + torch.manual_seed(seed) - inputs = torch.randn((T, D), device="xla") - loras = torch.randn((N, 1, L, D), device="xla") + inputs = torch.randn((T, D), device="xla", dtype=dtype) + loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") ref_output = ref_bgmv(inputs, loras, idxs) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index c95bfc7c8e09..122b6dd43064 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -49,7 +49,7 @@ def _bgmv( bL = 128 bD = 128 - # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register L1 = L if bL > L or L % bL != 0: @@ -121,8 +121,8 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): _, L, D = loras.shape # FIXME: Routing the output from 1 Pallas kernel directly to another results - # in NaN outputs so here we fallback on a reference implementation until the - # bug is fixed. The kernel can be used for either shrink or expand, but not + # in NaN outputs so here we fallback on a reference implementation until the + # bug is fixed. The kernel can be used for either shrink or expand, but not # both at the same time. use_reference = True if use_reference and L < D or not use_reference and D < L: From f7f95e453157b4cc404570911fe45480fcd98661 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 14:47:42 +0000 Subject: [PATCH 076/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index f118e011d599..406b557827ba 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -13,8 +13,8 @@ # inherit this class class PunicaWrapperTPU(PunicaWrapperBase): """ - PunicaWrapperTPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. """ @@ -23,8 +23,8 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - # PunicaWrapperBase defines some tensors with dtype=torch.int64, which isn't supported by the TPU. - # So convert those tensors to int32. 
+ # PunicaWrapperBase defines some tensors with dtype=torch.int64, which + # isn't supported by the TPU. So convert those tensors to int32. # Not all of them are used by the TPU so only convert the useful ones. self._token_lora_indices = self._token_lora_indices.to( dtype=torch.int32) @@ -71,11 +71,11 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. - + Semantics: for i in range(len(lora_a_stacked)): y[i] += (x @ lora_a_stacked[i]) * scale - + Args: y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor @@ -110,19 +110,19 @@ def add_expand(self, **kwargs) -> torch.Tensor: """ Performs GEMM and bias addition for multiple slices of lora_b. - + Semantics: for i in range(len(lora_b_stacked)): slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + - lora_bias_stacked[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] offset += slice - + Args: y (torch.Tensor): Output tensor. x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): bias's weight output_slices (Tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. @@ -180,7 +180,7 @@ def add_lora_linear(self, buffer: Optional[Tuple[torch.Tensor, ...]] = None, **kwargs) -> torch.Tensor: """ - Applicable to linear-related lora. + Applicable to linear-related lora. Semantics: for i in range(len(lora_a_stacked)): @@ -238,7 +238,7 @@ def add_lora_logits(self, **kwargs) -> torch.Tensor: """ Applies lora specifically for LogitsProcessorWithLoRA. - + Semantics: buffer = (x @ lora_a_stacked) * scale y += buffer @ lora_b_stacked From adfdcdbbc7b6f48977af680c88bb0bdcbfe7931f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 6 Mar 2025 18:23:01 +0000 Subject: [PATCH 077/186] Fixed bug related to consecutive pallas kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 120 +++++++++++-------------- vllm/lora/punica_wrapper/punica_tpu.py | 4 +- 2 files changed, 53 insertions(+), 71 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 122b6dd43064..e7174b6c1e4e 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -10,6 +10,11 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) +# TODO: Tune these +TOKENS_BLOCK = 16 +LORA_RANK_BLOCK = 128 +DIM_BLOCK_SIZE = 128 + def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): @@ -37,59 +42,35 @@ def _(): @jax.jit def _bgmv( - idxs: jax.Array, # (T, ) int32 - inputs: jax.Array, # (T, D) model dtype - loras: jax.Array # (N, 1, L, D) model dtype -) -> jax.Array: # (T, L) model dtype + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, L, D) model dtype +) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape - # TODO: Tune these - bT = 8 - bL = 128 - bD = 128 - - # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU - # register - L1 = L - if bL > L or L % bL != 0: - L1 = (L // bL + 1) * bL - - D1 = D - if bD > D or D % bD != 0: - D1 = (D // bD + 1) * bD - - T1 = T - if bT > T or T % bT != 0: - T1 = (T // bT + 1) * bT - - loras = jnp.pad(loras, ((0, 0), (0, L1 - L), (0, D1 - D))) - inputs = jnp.pad(inputs, ((0, T1 - T), (0, D1 - D))) - - return pl.pallas_call(kernel=functools.partial(_bgmv_kernel, bT, bL), - out_shape=jax.ShapeDtypeStruct((T1, L1), - dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T1 // bT, L1 // bL, D1 // bD), - in_specs=[ - pl.BlockSpec((bT, bD), - lambda i, j, k, block_idx: - (i, k)), - pl.BlockSpec((N, bL, bD), - lambda i, j, k, block_idx: - (0, j, k)), - ], - out_specs=pl.BlockSpec( - (bT, bL), lambda i, j, k, block_idx: (i, j)), - scratch_shapes=[ - pltpu.VMEM((bT, bL), jnp.float32), - pltpu.VMEM((bT, bL), jnp.float32) - ]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", - "arbitrary")), - name="bgmv")(idxs, inputs, loras)[:T, :L] + return pl.pallas_call( + kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK, + D // DIM_BLOCK_SIZE), + in_specs=[ + pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (0, j, k)), + ], + out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32), + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv")(idxs, inputs, loras) def bgmv_shape_function(idxs, inputs, loras): @@ -102,15 +83,6 @@ def bgmv_shape_function(idxs, inputs, loras): XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) -def ref_bgmv(inputs: jax.Array, loras: jax.Array, idxs: jax.Array): - selected_loras = loras[idxs] - n_tokens, output_size, input_size = selected_loras.shape - outputs = (selected_loras @ inputs.reshape( - (n_tokens, input_size, 1))).reshape((n_tokens, output_size)) - - return outputs - - @impl(XLA_LIB, "bgmv", "XLA") def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) @@ -118,20 +90,30 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): if len(loras.shape) == 4: loras = loras.squeeze(axis=1) + jax_import_guard() + kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + + T, _ = inputs.shape _, L, D = loras.shape - # FIXME: Routing the output from 1 Pallas kernel directly to another results - # in NaN outputs so here we fallback on a reference implementation until the - # bug is fixed. The kernel can be used for either shrink or expand, but not - # both at the same time. - use_reference = True - if use_reference and L < D or not use_reference and D < L: - return ref_bgmv(inputs, loras, idxs) + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs + L1 = L + if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0: + L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK - jax_import_guard() - kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + D1 = D + if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: + D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE + + T1 = T + if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: + T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK + + loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) + inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) - return kernel(idxs, inputs, loras) + return kernel(idxs, inputs, loras)[:T, :L] @impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 406b557827ba..d47b335065f1 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -192,7 +192,7 @@ def add_lora_linear(self, ).squeeze(0)+lora_bias_stacked[i] Args: - y (torch.Tensor): Output tensor. Will be changed in-place. + y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. @@ -237,7 +237,7 @@ def add_lora_logits(self, buffer: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: """ - Applies lora specifically for LogitsProcessorWithLoRA. + Applies lora specifically for LogitsProcessorWithLoRA. Semantics: buffer = (x @ lora_a_stacked) * scale From 5a277852dbe8f83aa5fb26937cbea92daad55ab0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 7 Mar 2025 11:35:46 +0000 Subject: [PATCH 078/186] Removed v0 TPU LoRA implementation Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_model_runner.py | 200 +++----------------------------- vllm/worker/tpu_worker.py | 21 +--- 2 files changed, 18 insertions(+), 203 deletions(-) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index bb973f883248..53541a2579ed 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -3,8 +3,8 @@ import enum import time from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, - Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) from unittest.mock import patch import numpy as np @@ -17,12 +17,8 @@ from vllm.config import VllmConfig from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import supports_lora from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) @@ -66,7 +62,6 @@ class ModelInputForTPU(ModelRunnerInputBase): num_samples: int n: List[int] seq_groups: List[List[int]] - lora_inputs: List[Tuple[Set[LoRARequest], LoRAMapping]] is_first_multi_step: bool = True is_last_step: bool = True virtual_engine: int = 0 @@ -77,7 +72,6 @@ def as_broadcastable_tensor_dict( tensor_dict = { "token_ids": 
self.token_ids, "position_ids": self.position_ids, - "lora_inputs": self.lora_inputs, "input_lens": self.input_lens, "t": self.t, "p": self.p, @@ -129,9 +123,6 @@ def __init__( ) self.cached_step_outputs: List[torch.Tensor] = [] - # LoRA support - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - smem_size = 512 * 1024 block_table_size = 4 * self.block_tables.size if block_table_size >= smem_size: @@ -163,29 +154,8 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = model - - if self.lora_config: - assert supports_lora( - self.model - ), f"{self.model.__class__.__name__} does not support LoRA yet." - - max_pos_embeddings = self.model.config.max_position_embeddings - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.model_config.get_vocab_size(), - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=max_pos_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - self.model = ModelWrapper(self.model) - self.model = torch.compile(self.model, + model = ModelWrapper(model) + self.model = torch.compile(model, backend="openxla", fullgraph=True, dynamic=False) @@ -281,29 +251,6 @@ def _dummy_run( p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 - # Create a series of dummy loras and requests for them. - # Make to fill all lora slots. - if self.lora_config: - dummy_lora_requests: Set[LoRARequest] = set() - dummy_lora_mapping: LoRAMapping - - assert self.lora_manager is not None - with self.lora_manager.dummy_lora_cache(): - for lora_id in range(1, self.lora_config.max_loras + 1): - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora( - dummy_lora_request, - rank=self.lora_config.max_lora_rank) - dummy_lora_requests.add(dummy_lora_request) - dummy_lora_mapping = LoRAMapping( - [lora_id] * batch_size * seq_len, [lora_id] * batch_size, - is_prefill=exec_mode.is_prefill()) - self.set_active_loras(dummy_lora_requests, dummy_lora_mapping) - # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile # overhead by reusing the FX graph for different shapes. @@ -312,25 +259,19 @@ def _dummy_run( # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). 
if exec_mode.is_prefill(): - # Prefill - if self.lora_config is not None: - torch._dynamo.config.capture_dynamic_output_shape_ops = True - else: - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) else: # Decode - if self.lora_config is not None: - torch._dynamo.config.capture_dynamic_output_shape_ops = True - else: - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) - torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. with set_forward_context(attn_metadata, self.vllm_config, 0): @@ -388,7 +329,7 @@ def warmup_model( # Decode start = time.time() seq_len = 1 - batch_size = _get_padded_batch_size(1) + batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: self._dummy_run(batch_size, seq_len, @@ -641,77 +582,8 @@ def prepare_model_input( list(metadata.seq_data.keys()) for metadata in seq_group_metadata_list ] - - lora_inputs = [] - if self.load_config is not None: - lora_inputs = self._prepare_lora_input(seq_group_metadata_list, - is_prompt, - padded_batch_size) - - return ModelInputForTPU(token_ids=input_tokens, - position_ids=input_positions, - attn_metadata=attn_metadata, - input_lens=input_lens, - t=t, - p=p, - num_samples=num_samples, - n=n, - seq_groups=seq_groups, - lora_inputs=lora_inputs) - - def _prepare_lora_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - is_prefill: bool, padded_batch_size: int - ) -> List[Tuple[Set[LoRARequest], LoRAMapping]]: - """ - Prepares a list of LoRA inputs. 
If we're decoding then the list will - only have 1 item, otherwise there'll be an item for each sequence - """ - - lora_input = [] - if is_prefill: - for seq in seq_group_metadata_list: - lora_id = seq.lora_int_id - query_len = seq.token_chunk_size - padded_query_len = _get_padded_prefill_len(query_len) - - index_mapping = [lora_id] * padded_query_len - prompt_mapping = [lora_id] - - lora_request = set() - if seq.lora_request is not None: - lora_request.add(seq.lora_request) - - lora_input.append( - (lora_request, - LoRAMapping(index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=True))) - else: - lora_request = set() - index_mapping = [] - prompt_mapping = [] - for seq in seq_group_metadata_list: - lora_id = seq.lora_int_id - - index_mapping += [lora_id] - prompt_mapping += [lora_id] - - if seq.lora_request is not None: - lora_request.add(seq.lora_request) - - index_mapping += [0] * (padded_batch_size - - len(seq_group_metadata_list)) - prompt_mapping += [0] * (padded_batch_size - - len(seq_group_metadata_list)) - - lora_input.append( - (lora_request, - LoRAMapping(index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=False))) - - return lora_input + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, n, seq_groups) def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: @@ -728,7 +600,6 @@ def execute_model( num_steps: int = 1, ) -> List[SamplerOutput]: assert intermediate_tensors is None - if not model_input.is_first_multi_step: if not model_input.is_last_step: return [] @@ -804,12 +675,6 @@ def execute_model( input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - - if self.lora_config is not None: - assert len(model_input.lora_inputs) == batch_size - lora_requests, lora_mapping = model_input.lora_inputs[i] - self.set_active_loras(lora_requests, lora_mapping) - with set_forward_context(model_input.attn_metadata, self.vllm_config, model_input.virtual_engine): @@ -859,12 +724,6 @@ def execute_model( t = model_input.t.to(self.device) p = model_input.p.to(self.device) input_lens = model_input.input_lens.to(self.device) - - if self.lora_config is not None: - assert len(model_input.lora_inputs) == 1 - lora_requests, lora_mapping = model_input.lora_inputs[0] - self.set_active_loras(lora_requests, lora_mapping) - for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping with set_forward_context(model_input.attn_metadata, @@ -907,37 +766,6 @@ def execute_model( model_input.seq_groups) return [sampler_output] - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise 
RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - class ModelWrapper(nn.Module): diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index c73327b7c8db..1a5eaba09b94 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import List, Optional, Set, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch_xla.core.xla_model as xm @@ -13,18 +13,18 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoRANotSupportedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) -class TPUWorker(LocalOrDistributedWorkerBase): +class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): def __init__( self, @@ -85,7 +85,6 @@ def init_device(self) -> None: # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and # 30-40 graphs for decode. 128 is an arbitrary safe number. torch._dynamo.config.cache_size_limit = 128 - torch._dynamo.config.reorderable_logging_functions = set([print]) # Use persistent cache to avoid XLA recompilation. # NOTE(woosuk): Set per-rank cache path since different ranks # can have slightly different XLA graphs. 
@@ -289,18 +288,6 @@ def execute_worker(self, worker_input: WorkerInput) -> None: attn_backend.copy_blocks(self.tpu_cache, (src_indices, dst_indices)) - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() - def _make_src_to_dst( mapping: List[Tuple[int, int]], From 5d15fbcf3e12247ea4bc3a100a1b929e67f92abe Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Sat, 8 Mar 2025 10:25:12 +0000 Subject: [PATCH 079/186] Fixed VocabParallelEmbeddingWithLoRA compilation error Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 ++++ vllm/lora/punica_wrapper/punica_tpu.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 912f8aef8a26..7ffff29c4267 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -239,6 +239,10 @@ def set_lora( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 embeddings_indices = self.punica_wrapper.embeddings_indices + + if current_platform.is_tpu(): + embeddings_indices = embeddings_indices[:, :x.size(0)] + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index d47b335065f1..4b0d760767f4 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -28,6 +28,15 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, # Not all of them are used by the TPU so only convert the useful ones. self._token_lora_indices = self._token_lora_indices.to( dtype=torch.int32) + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + return self._embeddings_indices[:] def shrink( self, From ca3d8107c495e1deacf71220a57ba34ebfac7728 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Mar 2025 11:47:31 +0000 Subject: [PATCH 080/186] Fixed LogitsProcessorWithLoRA layer compilation issue Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 4b0d760767f4..85a554729937 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -28,6 +28,8 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, # Not all of them are used by the TPU so only convert the useful ones. self._token_lora_indices = self._token_lora_indices.to( dtype=torch.int32) + self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to(dtype=torch.int32) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) @property @@ -260,6 +262,9 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. 
""" + if self.no_lora: + return y + y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) @@ -270,9 +275,8 @@ def add_lora_logits(self, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - # LogitsProcessorWithLoRA always using bgmv. - buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, - scale) + + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) y = bgmv_expand(buffer, lora_b_stacked, y, From 12f71cec70f97678cbf07aec736b0314eb1918da Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Mar 2025 11:47:41 +0000 Subject: [PATCH 081/186] Slightly sped up the kernel Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index e7174b6c1e4e..9244e565955f 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -110,8 +110,10 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK - loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) - inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) + if D1 != D or L1 != L: + loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) + if D1 != D or T1 != T: + inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) return kernel(idxs, inputs, loras)[:T, :L] @@ -120,6 +122,10 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape - _, _, L, _ = loras.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, L, _ = loras.shape return torch.empty((T, L), device=inputs.device) From d040ee81176d35ea6177a438d37a739310673a1f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Mar 2025 12:00:23 +0000 Subject: [PATCH 082/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 ++-- vllm/lora/ops/xla_ops/pallas.py | 4 ++-- vllm/lora/punica_wrapper/punica_tpu.py | 10 ++++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7ffff29c4267..94c83c0c0157 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -239,10 +239,10 @@ def set_lora( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 embeddings_indices = self.punica_wrapper.embeddings_indices - + if current_platform.is_tpu(): embeddings_indices = embeddings_indices[:, :x.size(0)] - + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 9244e565955f..8e8c12104c6a 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -122,10 +122,10 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape - + if len(loras.shape) == 4: loras = loras.squeeze(axis=1) - + _, L, _ = loras.shape return torch.empty((T, L), device=inputs.device) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 85a554729937..6a39b5cb7c86 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ 
b/vllm/lora/punica_wrapper/punica_tpu.py @@ -29,9 +29,10 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._token_lora_indices = self._token_lora_indices.to( dtype=torch.int32) self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) - self._sampler_indices_padded = self._sampler_indices_padded.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to( + dtype=torch.int32) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) - + @property def embeddings_indices(self) -> torch.Tensor: """ @@ -275,8 +276,9 @@ def add_lora_logits(self, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - - buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, + scale) y = bgmv_expand(buffer, lora_b_stacked, y, From e696144886b7c15aa5b2548f0c5deb72112c6039 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Mar 2025 15:57:00 +0000 Subject: [PATCH 083/186] Fixed bug with higher batch sizes Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 1296809a478c..9ba660c08b68 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -438,8 +438,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) if self.lora_config is not None: - self.set_active_loras(self.input_batch, - np.array(padded_total_num_scheduled_tokens)) + self.set_active_loras(self.input_batch, + num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, From d110613754c126fb49a39d0dc6f70829de9fb5ed Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 10 Mar 2025 16:05:13 +0000 Subject: [PATCH 084/186] Lint Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 9ba660c08b68..845e466cb452 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -438,7 +438,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) if self.lora_config is not None: - self.set_active_loras(self.input_batch, + self.set_active_loras(self.input_batch, num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( From f8d5da2f6c35db84cb07f6961337a14a1df5f7cb Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Mar 2025 11:28:13 +0000 Subject: [PATCH 085/186] Removed TODO in bgmv pallas test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index dbe8fcbfcb9b..3490246d5991 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -59,15 +59,15 @@ def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", [0]) -def test_bgmv(T, D, L, N, dtype, op_type, seed): +def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): + if op_type == "expand": + 
D, L = L, D + inputs, loras, idxs, ref_output = generate_test_data( T, D, L, N, seed, dtype) # Run bgmv - if op_type == "expand": - output = torch.ops.xla.bgmv(inputs, loras, idxs) # TODO: Specialise - else: - output = torch.ops.xla.bgmv(inputs, loras, idxs) + output = torch.ops.xla.bgmv(inputs, loras, idxs) # Make sure we have no NaNs assert not torch.any(torch.isnan(output)) From d11437717c6f73c70d4867c14d003d068a571dc5 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Mar 2025 11:32:31 +0000 Subject: [PATCH 086/186] Fixed PunicaWrapperBase typing Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 0332867055b7..570cd1b756a9 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -342,7 +342,7 @@ def update_metadata( @abstractmethod def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs) -> None: + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -369,7 +369,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -401,7 +401,7 @@ def add_lora_embedding(self, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -428,7 +428,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -463,7 +463,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. 
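For reference, the punica-style ops wired up throughout this series all follow the same per-token semantics spelled out in the add_lora_logits docstring above: shrink projects hidden states down through lora_a with a scaling factor, expand projects them back up through lora_b, and every token selects its adapter by an index tensor. The stand-alone sketch below (plain CPU PyTorch; simple_bgmv and the shapes chosen here are illustrative assumptions, not the Pallas kernel or the exact wrapper code added in these patches) shows those semantics end to end.

    import torch

    def simple_bgmv(x: torch.Tensor, loras: torch.Tensor,
                    idxs: torch.Tensor) -> torch.Tensor:
        # x: (T, K), loras: (N, O, K), idxs: (T,) with values in [0, N)
        # For each token t: out[t] = loras[idxs[t]] @ x[t]  -> shape (T, O)
        return torch.einsum("tk,tok->to", x, loras[idxs])

    T, D, R, N = 8, 64, 16, 4           # tokens, hidden dim, LoRA rank, num LoRAs
    x = torch.randn(T, D)
    lora_a = torch.randn(N, R, D)       # shrink weights: D -> R
    lora_b = torch.randn(N, D, R)       # expand weights: R -> D
    idxs = torch.randint(0, N, (T,), dtype=torch.int32)

    scale = 0.5
    buffer = scale * simple_bgmv(x, lora_a, idxs)   # bgmv_shrink semantics
    y = torch.zeros(T, D)
    y += simple_bgmv(buffer, lora_b, idxs)          # bgmv_expand with add_inputs=True
    print(y.shape)                                  # torch.Size([8, 64])

The einsum form is only a readability aid; the point of the series is to replace it with a blocked Pallas kernel that performs the same gather-then-matmul per token on TPU.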
From 430bae9d5c5faa5c7dee6e330ff0f7a869b4a179 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Mar 2025 13:14:13 +0000 Subject: [PATCH 087/186] Fixed bug where vLLM crashes on decode Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 10 +++++----- vllm/v1/worker/tpu_model_runner.py | 9 ++++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 5f051575d3fc..f1b642e5852c 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -12,7 +12,7 @@ def bgmv_expand(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - batch_size = outputs.size(0) + n_tokens = outputs.size(0) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -20,7 +20,7 @@ def bgmv_expand(inputs: torch.Tensor, outputs = torch.cat( (outputs, - torch.zeros((batch_size, output_tensor.shape[1] - outputs.shape[1]), + torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), device=outputs.device)), dim=1) @@ -48,13 +48,13 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_size: int, add_inputs: bool = True): outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - batch_size = outputs.size(0) + n_tokens = outputs.size(0) outputs = torch.cat(( - torch.zeros((batch_size, slice_offset), device=outputs.device), + torch.zeros((n_tokens, slice_offset), device=outputs.device), outputs, torch.zeros( - (batch_size, output_tensor.shape[1] - (slice_offset + slice_size)), + (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), device=outputs.device), ), dim=1) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 845e466cb452..887a0fdce2a8 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -438,8 +438,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[ + -1] += padded_total_num_scheduled_tokens - total_num_scheduled_tokens + self.set_active_loras(self.input_batch, - num_scheduled_tokens_per_req) + padded_num_scheduled_tokens_per_req) attn_metadata = PallasMetadata( slot_mapping=slot_mapping, From fb36fd6952a7843b563a414ef29a1a42a4e9f864 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 11 Mar 2025 14:16:07 +0000 Subject: [PATCH 088/186] Fixed NaN bug with LogitsProcessor Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 2 +- vllm/lora/ops/xla_ops/pallas.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 94c83c0c0157..ec40ff6edc13 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1159,7 +1159,7 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - # LogitsProcessorWithLoRA always using bgmv + # LogitsProcessorWithLoRA always uses bgmv lora_output = self.punica_wrapper.add_lora_logits( logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py 
index 8e8c12104c6a..35dc307539bf 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -114,6 +114,8 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) if D1 != D or T1 != T: inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) + if T1 != T: + idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) return kernel(idxs, inputs, loras)[:T, :L] From 23b14d137572ddf7dba1a84665fc6dbdc12d04df Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Mar 2025 11:10:05 +0000 Subject: [PATCH 089/186] Updated LoRALogitsProcessor to work with the TPU Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index ec40ff6edc13..7a40b4054993 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1151,8 +1151,8 @@ def _get_logits( posinf=pos_inf, neginf=neg_inf)) - # HPU needs special handling to prune out dummy samples. - if current_platform.is_hpu(): + # TPU/HPU needs special handling to prune out dummy samples. + if current_platform.is_hpu() or current_platform.is_tpu(): lora_logits = lora_logits[:logits.shape[0], :] logits[:, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bb7250ee78e8..db7e407611c4 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -780,7 +780,7 @@ def capture_model(self) -> None: start = time.perf_counter() num_tokens = 16 while True: - with self.maybe_profile_with_lora( + with self.maybe_dummy_run_with_lora( self.lora_config, np.array([num_tokens], dtype=np.int32)): self._dummy_run(self.kv_caches, num_tokens) logger.info(" -- num_tokens: %d", num_tokens) From 27d6f707401cc384f6d95c5139a6dc509539f233 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Mar 2025 11:13:01 +0000 Subject: [PATCH 090/186] Lint Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index db7e407611c4..ffe711dbddf8 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -437,8 +437,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): padded_num_scheduled_tokens_per_req = np.copy( num_scheduled_tokens_per_req ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[ - -1] += padded_total_num_scheduled_tokens - total_num_scheduled_tokens + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) From b5472718dcd969b9eea215305b2eb38133a7591a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Mar 2025 16:09:51 +0000 Subject: [PATCH 091/186] Fixed batched logits processing Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 7 +++++-- vllm/lora/punica_wrapper/punica_tpu.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7a40b4054993..bb2dd63c8b62 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1144,6 +1144,9 @@ def _get_logits( lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + if current_platform.is_tpu(): + 
indices_padded = indices_padded[:logits.size(0)] + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], @@ -1151,8 +1154,8 @@ def _get_logits( posinf=pos_inf, neginf=neg_inf)) - # TPU/HPU needs special handling to prune out dummy samples. - if current_platform.is_hpu() or current_platform.is_tpu(): + # HPU needs special handling to prune out dummy samples. + if current_platform.is_hpu(): lora_logits = lora_logits[:logits.shape[0], :] logits[:, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 6a39b5cb7c86..38b4b393be8b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -32,6 +32,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) @property def embeddings_indices(self) -> torch.Tensor: @@ -41,6 +42,13 @@ def embeddings_indices(self) -> torch.Tensor: """ return self._embeddings_indices[:] + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + def shrink( self, y: torch.Tensor, From f5138b83388136e6e4d81857a225030ae33bf00e Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Mar 2025 17:28:27 +0000 Subject: [PATCH 092/186] Updated kernel test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 3490246d5991..5be0277b82c6 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -5,15 +5,22 @@ # Required to register the custom ops import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import -N_TOKENS = [ - 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, - 131072 -] -HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] +# N_TOKENS = [ +# 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, +# 131072 +# ] +# HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] -DTYPES = [torch.float16, torch.bfloat16] -NUM_LORA = [1, 2, 4, 8, 16, 32] -RANKS = [8, 16, 32, 64, 128] +# DTYPES = [torch.float16, torch.bfloat16] +# NUM_LORA = [1, 2, 4, 8, 16, 32] +# RANKS = [8, 16, 32, 64, 128] + +N_TOKENS = [2048] +HIDDEN_SIZES = [4096] + +DTYPES = [torch.bfloat16] +NUM_LORA = [1, 2, 4] +RANKS = [8] def generate_test_data(T, D, L, N, seed, dtype=torch.float32): @@ -73,4 +80,4 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): assert not torch.any(torch.isnan(output)) # Compare with reference output - assert torch.allclose(output, ref_output, rtol=1e-3, atol=1e-3) + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) From ad1487257ef3c9a9afcc158fd8bf7c67745b6a14 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 12 Mar 2025 17:28:38 +0000 Subject: [PATCH 093/186] Added kernel benchmark (dev only, remove later) Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 92 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 bmark_kernels.py diff --git a/bmark_kernels.py b/bmark_kernels.py new file mode 100644 index 
000000000000..c1bc08ced9af --- /dev/null +++ b/bmark_kernels.py @@ -0,0 +1,92 @@ +from functools import partial +import itertools +import pytest + +import torch +import torch_xla.core.xla_model as xm +import vllm.lora.ops.xla_ops.pallas as pl + +def create_tensors(T, D, L, N, dtype=torch.bfloat16, device='xla'): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: torch.Tensor - shape (T, D) + loras: torch.Tensor - shape (N, 1, L, D) + idxs: torch.IntTensor - shape (T, ) - all values must be in [0, N) + """ + + inputs = torch.randn((T, D), dtype=dtype, device=device) + loras = torch.randn((N, L, D), dtype=dtype, device=device) + idxs = torch.randint(0, N, (T,), dtype=torch.int32, device=device) + + return inputs, loras, idxs + +# SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 131072] +# HIDDEN_DIM = [256, 1024, 4096, 8192, 14336, 28672] +# LORA_RANKS = [8, 16, 32, 64, 128, 128] +# N_LORAS = [1, 2, 4, 8] + + +SEQ_LENS = [1024] +HIDDEN_DIM = [4096] +LORA_RANKS = [32] +N_LORAS = [1] + +@torch.compile(fullgraph=True, dynamic=False, backend="openxla") +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + return torch.einsum("td,tld->tl", inputs, loras[idxs]) + +@torch.compile(fullgraph=True, dynamic=False, backend="openxla") +def bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + return torch.ops.xla.bgmv(inputs, loras, idxs) + +@torch.compile(fullgraph=True, dynamic=False, backend="openxla") +def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): + # TODO: Fuse kernels + return bgmv( + bgmv(inputs, loras_a, idxs), + loras_b, + idxs + ) + +@torch.compile(fullgraph=True, dynamic=False, backend="openxla") +def ref_shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): + return ref_bgmv( + ref_bgmv(inputs, loras_a, idxs), + loras_b, + idxs + ) + +def run_and_wait_torch(func, *args): + out = func(*args) + xm.mark_step() + xm.wait_device_ops() + return out + +@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) +def test_bmark_shrink(benchmark, T, D, L, N, func): + inputs, loras, idxs = create_tensors(T, D, L, N) + + benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=100) + +@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, LORA_RANKS, HIDDEN_DIM, N_LORAS)) +@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) +def test_bmark_expand(benchmark, T, D, L, N, func): + inputs, loras, idxs = create_tensors(T, D, L, N) + + benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=100) + + +@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) +@pytest.mark.parametrize("func", [shrink_and_expand, ref_shrink_and_expand]) +def test_bmark_shrink_and_expand(benchmark, T, D, L, N, func): + inputs, loras_a, idxs = create_tensors(T, D, L, N) + _, loras_b, _ = create_tensors(T, L, D, N) + + benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras_a, loras_b, idxs), rounds=5, warmup_rounds=5, iterations=100) From 7418b5a1eabbd218b22b79328101c3679fa65238 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 13 Mar 2025 14:32:06 +0000 Subject: [PATCH 094/186] Tuned bgmv 
kernel block sizes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 61 ++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 35dc307539bf..8aa9bbdd4ac2 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -10,15 +10,8 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -# TODO: Tune these -TOKENS_BLOCK = 16 -LORA_RANK_BLOCK = 128 -DIM_BLOCK_SIZE = 128 - - def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): - @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) @@ -40,33 +33,37 @@ def _(): out_ref[...] = acc_ref[...].astype(out_ref.dtype) -@jax.jit +@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK_SIZE", "LORA_RANK_BLOCK_SIZE", "DIM_BLOCK_SIZE"]) def _bgmv( idxs: jax.Array, # (T, ) int32 inputs: jax.Array, # (T, D) model dtype - loras: jax.Array # (N, L, D) model dtype + loras: jax.Array, # (N, L, D) model dtype + *, + TOKEN_BLOCK_SIZE: int, + LORA_RANK_BLOCK_SIZE: int, + DIM_BLOCK_SIZE: int ) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK), + kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK, + grid=(T // TOKEN_BLOCK_SIZE, L // LORA_RANK_BLOCK_SIZE, D // DIM_BLOCK_SIZE), in_specs=[ - pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), + pl.BlockSpec((TOKEN_BLOCK_SIZE, DIM_BLOCK_SIZE), lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), + pl.BlockSpec((N, LORA_RANK_BLOCK_SIZE, DIM_BLOCK_SIZE), lambda i, j, k, block_idx: (0, j, k)), ], - out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK), + out_specs=pl.BlockSpec((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ - pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32), - pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32) + pltpu.VMEM((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), jnp.float32), + pltpu.VMEM((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), @@ -90,25 +87,43 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): if len(loras.shape) == 4: loras = loras.squeeze(axis=1) - jax_import_guard() - kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - T, _ = inputs.shape _, L, D = loras.shape + jax_import_guard() + + TOKEN_BLOCK_SIZE=16 + if L > D: # Expand + LORA_RANK_BLOCK_SIZE=1024 + DIM_BLOCK_SIZE=256 + else: # Shrink + LORA_RANK_BLOCK_SIZE=256 + DIM_BLOCK_SIZE=1024 + + + kernel = make_kernel_from_pallas( + functools.partial( + _bgmv, + TOKEN_BLOCK_SIZE=TOKEN_BLOCK_SIZE, + LORA_RANK_BLOCK_SIZE=LORA_RANK_BLOCK_SIZE, + DIM_BLOCK_SIZE=DIM_BLOCK_SIZE + ), + bgmv_shape_function + ) + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs L1 = L - if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0: - L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK + if LORA_RANK_BLOCK_SIZE > L or L % LORA_RANK_BLOCK_SIZE != 0: + L1 = (L // LORA_RANK_BLOCK_SIZE + 1) * LORA_RANK_BLOCK_SIZE D1 = D if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE T1 = T - if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: - T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK + if TOKEN_BLOCK_SIZE > T or T % TOKEN_BLOCK_SIZE != 0: + T1 = (T // TOKEN_BLOCK_SIZE + 1) * TOKEN_BLOCK_SIZE if D1 != D or L1 != L: loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) From 2aacb34030c282c4a66eee7f133ce0a893c64351 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 13 Mar 2025 15:00:06 +0000 Subject: [PATCH 095/186] Improved lora output masking Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 8aa9bbdd4ac2..23a19c48b3b7 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -10,22 +10,24 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, +def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) t = pl.program_id(0) - - for i in range(bT): - idx = idx_ref[i + bT * t] + + for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + for j in range(bT): + @pl.when(idx_ref[j + bT * t] == i) + def _(): + mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) acc_ref[...] += jax.lax.dot_general( inp_ref[...], - lora_ref[idx, ...], (((1, ), (1, )), ((), ())), + lora_ref[i, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) @@ -47,7 +49,7 @@ def _bgmv( N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), + kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, From 6ee0b57858967e35b84a442121b89a0dff6ef333 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 13 Mar 2025 16:19:07 +0000 Subject: [PATCH 096/186] Skipped matmuls where no loras are needed Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 23a19c48b3b7..fc9c96e4a8f2 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -20,15 +20,20 @@ def _(): for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + valid = False for j in range(bT): + valid |= idx_ref[j + bT * t] == i + @pl.when(idx_ref[j + bT * t] == i) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) - acc_ref[...] 
+= jax.lax.dot_general( - inp_ref[...], - lora_ref[i, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] + @pl.when(valid) + def _(): + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[i, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): From d9e415f18a3f18e028d59d47c7bce4352504244c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Mar 2025 00:38:05 +0000 Subject: [PATCH 097/186] Renamed variables for better readabiity Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 52 ++++++++++++++++----------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index fc9c96e4a8f2..fee8c30ba1b3 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -40,37 +40,37 @@ def _(): out_ref[...] = acc_ref[...].astype(out_ref.dtype) -@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK_SIZE", "LORA_RANK_BLOCK_SIZE", "DIM_BLOCK_SIZE"]) +@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) def _bgmv( idxs: jax.Array, # (T, ) int32 inputs: jax.Array, # (T, D) model dtype loras: jax.Array, # (N, L, D) model dtype *, - TOKEN_BLOCK_SIZE: int, - LORA_RANK_BLOCK_SIZE: int, - DIM_BLOCK_SIZE: int + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int ) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE, N), + kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - grid=(T // TOKEN_BLOCK_SIZE, L // LORA_RANK_BLOCK_SIZE, - D // DIM_BLOCK_SIZE), + grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, + D // DIM_BLOCK), in_specs=[ - pl.BlockSpec((TOKEN_BLOCK_SIZE, DIM_BLOCK_SIZE), + pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((N, LORA_RANK_BLOCK_SIZE, DIM_BLOCK_SIZE), + pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), lambda i, j, k, block_idx: (0, j, k)), ], - out_specs=pl.BlockSpec((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), + out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ - pltpu.VMEM((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), jnp.float32), - pltpu.VMEM((TOKEN_BLOCK_SIZE, LORA_RANK_BLOCK_SIZE), jnp.float32) + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), @@ -99,21 +99,21 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): jax_import_guard() - TOKEN_BLOCK_SIZE=16 + TOKEN_BLOCK=16 if L > D: # Expand - LORA_RANK_BLOCK_SIZE=1024 - DIM_BLOCK_SIZE=256 + LORA_BLOCK=1024 + DIM_BLOCK=256 else: # Shrink - LORA_RANK_BLOCK_SIZE=256 - DIM_BLOCK_SIZE=1024 + LORA_BLOCK=256 + DIM_BLOCK=1024 kernel = make_kernel_from_pallas( functools.partial( _bgmv, - TOKEN_BLOCK_SIZE=TOKEN_BLOCK_SIZE, - LORA_RANK_BLOCK_SIZE=LORA_RANK_BLOCK_SIZE, - DIM_BLOCK_SIZE=DIM_BLOCK_SIZE + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK ), bgmv_shape_function ) @@ -121,16 +121,16 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): # Pad the loras' rank 
if it's too low. This is to allow it to fit in a TPU # register. This has to happen in pytorch, doing it in Jax will lead to NaNs L1 = L - if LORA_RANK_BLOCK_SIZE > L or L % LORA_RANK_BLOCK_SIZE != 0: - L1 = (L // LORA_RANK_BLOCK_SIZE + 1) * LORA_RANK_BLOCK_SIZE + if LORA_BLOCK > L or L % LORA_BLOCK != 0: + L1 = (L // LORA_BLOCK + 1) * LORA_BLOCK D1 = D - if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: - D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE + if DIM_BLOCK > D or D % DIM_BLOCK != 0: + D1 = (D // DIM_BLOCK + 1) * DIM_BLOCK T1 = T - if TOKEN_BLOCK_SIZE > T or T % TOKEN_BLOCK_SIZE != 0: - T1 = (T // TOKEN_BLOCK_SIZE + 1) * TOKEN_BLOCK_SIZE + if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: + T1 = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK if D1 != D or L1 != L: loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) From 460e8081bf2dda639d117ccada1df32b7fe97079 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Mar 2025 00:48:01 +0000 Subject: [PATCH 098/186] Moved inner loop into grid spec Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 51 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index fee8c30ba1b3..fd02f659440f 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -10,32 +10,39 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): - @pl.when(pl.program_id(2) == 0) + + t = pl.program_id(0) + + d = pl.program_id(2) + ds = pl.num_programs(2) + + lora_idx = pl.program_id(3) + n_lora_idxs = pl.num_programs(3) + + @pl.when((d == 0) & (lora_idx == 0)) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - t = pl.program_id(0) + valid = False + for j in range(bT): + valid |= idx_ref[j + bT * t] == lora_idx - for i in range(max_num_loras): + @pl.when(valid) + def _(): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - valid = False for j in range(bT): - valid |= idx_ref[j + bT * t] == i - - @pl.when(idx_ref[j + bT * t] == i) + @pl.when(idx_ref[j + bT * t] == lora_idx) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) - @pl.when(valid) - def _(): - acc_ref[...] += jax.lax.dot_general( - inp_ref[...], - lora_ref[i, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[0, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + @pl.when((d == ds - 1) & (lora_idx == n_lora_idxs - 1)) def _(): out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @@ -54,26 +61,26 @@ def _bgmv( N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK, N), + kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, - D // DIM_BLOCK), + D // DIM_BLOCK, N), in_specs=[ pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), - lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), - lambda i, j, k, block_idx: (0, j, k)), + lambda t, l, d, n, block_idx: (t, d)), + pl.BlockSpec((1, LORA_BLOCK, DIM_BLOCK), + lambda t, l, d, n, block_idx: (n, l, d)), ], out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), - lambda i, j, k, block_idx: (i, j)), + lambda t, l, d, n, block_idx: (t, l)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary")), + dimension_semantics=("parallel", "parallel", "arbitrary", "arbitrary")), name="bgmv")(idxs, inputs, loras) From 12ac3b809c84872615eaf02054f68753857bbf4d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 17 Mar 2025 16:22:40 +0000 Subject: [PATCH 099/186] Revert "Moved inner loop into grid spec" This reverts commit 460e8081bf2dda639d117ccada1df32b7fe97079. Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 51 ++++++++++++++------------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index fd02f659440f..fee8c30ba1b3 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -10,39 +10,32 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, +def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): - - t = pl.program_id(0) - - d = pl.program_id(2) - ds = pl.num_programs(2) - - lora_idx = pl.program_id(3) - n_lora_idxs = pl.num_programs(3) - - @pl.when((d == 0) & (lora_idx == 0)) + @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - valid |= idx_ref[j + bT * t] == lora_idx + t = pl.program_id(0) - @pl.when(valid) - def _(): + for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + valid = False for j in range(bT): - @pl.when(idx_ref[j + bT * t] == lora_idx) + valid |= idx_ref[j + bT * t] == i + + @pl.when(idx_ref[j + bT * t] == i) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) - acc_ref[...] += jax.lax.dot_general( - inp_ref[...], - lora_ref[0, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] + @pl.when(valid) + def _(): + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[i, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when((d == ds - 1) & (lora_idx == n_lora_idxs - 1)) + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @@ -61,26 +54,26 @@ def _bgmv( N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK), + kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, - D // DIM_BLOCK, N), + D // DIM_BLOCK), in_specs=[ pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), - lambda t, l, d, n, block_idx: (t, d)), - pl.BlockSpec((1, LORA_BLOCK, DIM_BLOCK), - lambda t, l, d, n, block_idx: (n, l, d)), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), + lambda i, j, k, block_idx: (0, j, k)), ], out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), - lambda t, l, d, n, block_idx: (t, l)), + lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary", "arbitrary")), + dimension_semantics=("parallel", "parallel", "arbitrary")), name="bgmv")(idxs, inputs, loras) From af15bd1b16e2bbef42bd93fe6e7d49a46dc024a6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 16:58:14 +0000 Subject: [PATCH 100/186] Added comment Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 54e760b436d9..fdc411e52033 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,6 +267,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings, self.lora_b_stacked, add_input=True) + + # lora_output is None if the platform supports inplace updates. + # Otherwise it's a tensor, so we update the output manually if not current_platform.can_update_inplace(): full_output = lora_output From 41555d1fc3afc15cba013e472e68a1f0836f3e35 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 17:05:12 +0000 Subject: [PATCH 101/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index fdc411e52033..32eac38ecade 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -267,8 +267,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings, self.lora_b_stacked, add_input=True) - - # lora_output is None if the platform supports inplace updates. + + # lora_output is None if the platform supports inplace updates. 
# Otherwise it's a tensor, so we update the output manually if not current_platform.can_update_inplace(): full_output = lora_output From 4ac7aa9ac2c1310404ceeb4baf7a89b500927e8f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 20:09:53 +0000 Subject: [PATCH 102/186] Added a fused shrink/expand kernel Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 6 +- tests/lora/tpu/test_pallas_kernels.py | 23 +++ vllm/lora/ops/xla_ops/pallas.py | 192 ++++++++++++++++++++++++-- 3 files changed, 206 insertions(+), 15 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index c1bc08ced9af..409c36f32bff 100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -47,9 +47,9 @@ def bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): @torch.compile(fullgraph=True, dynamic=False, backend="openxla") def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - # TODO: Fuse kernels - return bgmv( - bgmv(inputs, loras_a, idxs), + return torch.ops.xla.fused_bgmv_shrink_expand( + inputs, + loras_a, loras_b, idxs ) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 5be0277b82c6..946e271cef11 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -81,3 +81,26 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): # Compare with reference output assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", [0]) +def test_fused_bgmv_shrink_expand_correctness(T, D, L, N, dtype, seed): + inputs, loras_a, idxs, intermediate_output = generate_test_data( + T, D, L, N, seed, dtype) + _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype) + + ref_output = ref_bgmv(intermediate_output, loras_b, idxs) + + # Run bgmv + output = torch.ops.xla.fused_bgmv_shrink_expand(inputs, loras_a, loras_b, idxs) + + # Make sure we have no NaNs + assert not torch.any(torch.isnan(output)) + + # Compare with reference output + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index fee8c30ba1b3..44461d1df1f5 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -12,18 +12,20 @@ def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): - @pl.when(pl.program_id(2) == 0) + t = pl.program_id(0) + d = pl.program_id(2) + ds = pl.num_programs(2) + + @pl.when(d == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - t = pl.program_id(0) - for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) valid = False for j in range(bT): valid |= idx_ref[j + bT * t] == i - + @pl.when(idx_ref[j + bT * t] == i) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) @@ -35,7 +37,7 @@ def _(): lora_ref[i, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + @pl.when(d == ds - 1) def _(): out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @@ -62,12 +64,12 @@ def _bgmv( D // DIM_BLOCK), in_specs=[ pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), - lambda i, j, k, block_idx: (i, k)), + lambda t, l, d, block_idx: (t, d)), pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), - lambda i, j, k, block_idx: (0, j, k)), + lambda t, l, d, block_idx: (0, l, d)), ], out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), - lambda i, j, k, block_idx: (i, j)), + lambda t, l, d, block_idx: (t, l)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) @@ -77,16 +79,118 @@ def _bgmv( name="bgmv")(idxs, inputs, loras) +def _fused_bgmv_shrink_expand_kernel( + bT: int, bL: int, bD: int, max_num_loras: int, + idx_ref, inp_ref, lora_a_ref, lora_b_ref, out_ref, cache_ref, acc_ref, mask_a_ref, mask_b_ref): + + t = pl.program_id(0) + d1 = pl.program_id(1) + l = pl.program_id(2) + d2 = pl.program_id(3) + + ls = pl.num_programs(2) + ds = pl.num_programs(3) + + should_compute_cache = d1 == 0 + @pl.when(should_compute_cache) + def _(): + @pl.when(d2 == 0) + def _(): + cache_ref[l, ...] = jnp.zeros_like(cache_ref[l, ...], dtype=jnp.float32) + + for i in range(max_num_loras): + mask_a_ref[...] = jnp.zeros_like(mask_a_ref[...], dtype=jnp.float32) + valid = False + for j in range(bT): + valid |= idx_ref[j + bT * t] == i + + @pl.when(idx_ref[j + bT * t] == i) + def _(): + mask_a_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) + + @pl.when(valid) + def _(): + cache_ref[l, ...] += jax.lax.dot_general( + inp_ref[...], + lora_a_ref[i, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_a_ref[...] + + cache_valid = d2 == (ds - 1) + @pl.when(cache_valid) + def _(): + @pl.when(l == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + for i in range(max_num_loras): + mask_b_ref[...] = jnp.zeros_like(mask_b_ref[...], dtype=jnp.float32) + valid = False + for j in range(bT): + valid |= idx_ref[j + bT * t] == i + + @pl.when(idx_ref[j + bT * t] == i) + def _(): + mask_b_ref[j, :] = jnp.ones((bD, ), dtype=jnp.float32) + + @pl.when(valid) + def _(): + acc_ref[...] += jax.lax.dot_general( + cache_ref[l, ...], + lora_b_ref[i, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_b_ref[...] + + @pl.when(l == ls - 1) + def _(): + out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) + +@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) +def _fused_bgmv_shrink_expand( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras_a: jax.Array, # (N, L, D) model dtype + loras_b: jax.Array, # (N, D, L) model dtype + *, + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int +) -> jax.Array: # (T, D) model dtype + T, D = inputs.shape + N, L, _ = loras_a.shape + + return pl.pallas_call( + kernel=functools.partial(_fused_bgmv_shrink_expand_kernel, TOKEN_BLOCK, LORA_BLOCK, DIM_BLOCK, N), + out_shape=jax.ShapeDtypeStruct((T, D), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKEN_BLOCK, D // DIM_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), + in_specs=[ + pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (t, d2)), # Inputs + pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (0, l, d2)), # LoRA A + pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK), lambda t, d1, l, d2, block_idx: (0, d1, l)), # LoRA B + ], + out_specs=pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (t, d1)), + scratch_shapes=[ + pltpu.VMEM((L // LORA_BLOCK, TOKEN_BLOCK, LORA_BLOCK), jnp.float32), # Intermediates cache + pltpu.VMEM((TOKEN_BLOCK, DIM_BLOCK), jnp.float32), # Final accumulator + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), # LoRA A mask + pltpu.VMEM((TOKEN_BLOCK, DIM_BLOCK), jnp.float32) # LoRA B mask + ] + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary", "arbitrary") + ), + name="fused_bgmv_shrink_expand" + )(idxs, inputs, loras_a, loras_b) + + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") + def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, L, _ = loras.shape return [((T, L), inputs.dtype)] - -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) - - @impl(XLA_LIB, "bgmv", "XLA") def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) @@ -153,3 +257,67 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, _, L, _ = loras.shape return torch.empty((T, L), device=inputs.device) + +def fused_bgmv_shrink_expand_shape_function(idxs, inputs, loras_a, loras_b): + return [(inputs.shape, inputs.dtype)] + +XLA_LIB.define("fused_bgmv_shrink_expand(Tensor inputs, Tensor loras_a, Tensor loras_b, Tensor idxs) -> Tensor") + +@impl(XLA_LIB, "fused_bgmv_shrink_expand", "XLA") +def fused_bgmv_shrink_expand_xla(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras_a.dtype) + + if len(loras_a.shape) == 4: + loras_a = loras_a.squeeze(axis=1) + if len(loras_b.shape) == 4: + loras_b = loras_b.squeeze(axis=1) + + T, _ = inputs.shape + _, L, D = loras_a.shape + + jax_import_guard() + + TOKEN_BLOCK=16 + LORA_BLOCK=256 + DIM_BLOCK=1024 + + + kernel = make_kernel_from_pallas( + functools.partial( + _fused_bgmv_shrink_expand, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK + ), + fused_bgmv_shrink_expand_shape_function + ) + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs + L1 = L + if LORA_BLOCK > L or L % LORA_BLOCK != 0: + L1 = (L // LORA_BLOCK + 1) * LORA_BLOCK + + D1 = D + if DIM_BLOCK > D or D % DIM_BLOCK != 0: + D1 = (D // DIM_BLOCK + 1) * DIM_BLOCK + + T1 = T + if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: + T1 = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK + + if D1 != D or L1 != L: + loras_a = torch.nn.functional.pad(loras_a, (0, D1 - D, 0, L1 - L, 0, 0)) + loras_b = torch.nn.functional.pad(loras_b, (0, L1 - L, 0, D1 - D, 0, 0)) + if D1 != D or T1 != T: + inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) + if T1 != T: + idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) + + return kernel(idxs, inputs, loras_a, loras_b)[:T, :D] + + +@impl(XLA_LIB, "fused_bgmv_shrink_expand", "CompositeExplicitAutograd") +def fused_bgmv_shrink_expand_non_xla(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, + idxs: torch.IntTensor): + return torch.empty_like(inputs) \ No newline at end of file From 9f5a4977a5058b74f0755075b58e168dfe877493 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 20:10:06 +0000 Subject: [PATCH 103/186] Revert "Added a fused shrink/expand kernel" This reverts commit 4ac7aa9ac2c1310404ceeb4baf7a89b500927e8f. Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 6 +- tests/lora/tpu/test_pallas_kernels.py | 23 --- vllm/lora/ops/xla_ops/pallas.py | 192 ++------------------------ 3 files changed, 15 insertions(+), 206 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index 409c36f32bff..c1bc08ced9af 100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -47,9 +47,9 @@ def bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): @torch.compile(fullgraph=True, dynamic=False, backend="openxla") def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - return torch.ops.xla.fused_bgmv_shrink_expand( - inputs, - loras_a, + # TODO: Fuse kernels + return bgmv( + bgmv(inputs, loras_a, idxs), loras_b, idxs ) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 946e271cef11..5be0277b82c6 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -81,26 +81,3 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): # Compare with reference output assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) - -# Parameterize tests with various shapes and dtypes -@pytest.mark.parametrize("T", N_TOKENS) -@pytest.mark.parametrize("D", HIDDEN_SIZES) -@pytest.mark.parametrize("L", RANKS) -@pytest.mark.parametrize("N", NUM_LORA) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", [0]) -def test_fused_bgmv_shrink_expand_correctness(T, D, L, N, dtype, seed): - inputs, loras_a, idxs, intermediate_output = generate_test_data( - T, D, L, N, seed, dtype) - _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype) - - ref_output = ref_bgmv(intermediate_output, loras_b, idxs) - - # Run bgmv - output = torch.ops.xla.fused_bgmv_shrink_expand(inputs, loras_a, loras_b, idxs) - - # Make sure we have no NaNs - assert not torch.any(torch.isnan(output)) - - # Compare with reference output - assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 44461d1df1f5..fee8c30ba1b3 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -12,20 +12,18 @@ def 
_bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): - t = pl.program_id(0) - d = pl.program_id(2) - ds = pl.num_programs(2) - - @pl.when(d == 0) + @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + t = pl.program_id(0) + for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) valid = False for j in range(bT): valid |= idx_ref[j + bT * t] == i - + @pl.when(idx_ref[j + bT * t] == i) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) @@ -37,7 +35,7 @@ def _(): lora_ref[i, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when(d == ds - 1) + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): out_ref[...] = acc_ref[...].astype(out_ref.dtype) @@ -64,12 +62,12 @@ def _bgmv( D // DIM_BLOCK), in_specs=[ pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), - lambda t, l, d, block_idx: (t, d)), + lambda i, j, k, block_idx: (i, k)), pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), - lambda t, l, d, block_idx: (0, l, d)), + lambda i, j, k, block_idx: (0, j, k)), ], out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), - lambda t, l, d, block_idx: (t, l)), + lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) @@ -79,118 +77,16 @@ def _bgmv( name="bgmv")(idxs, inputs, loras) -def _fused_bgmv_shrink_expand_kernel( - bT: int, bL: int, bD: int, max_num_loras: int, - idx_ref, inp_ref, lora_a_ref, lora_b_ref, out_ref, cache_ref, acc_ref, mask_a_ref, mask_b_ref): - - t = pl.program_id(0) - d1 = pl.program_id(1) - l = pl.program_id(2) - d2 = pl.program_id(3) - - ls = pl.num_programs(2) - ds = pl.num_programs(3) - - should_compute_cache = d1 == 0 - @pl.when(should_compute_cache) - def _(): - @pl.when(d2 == 0) - def _(): - cache_ref[l, ...] = jnp.zeros_like(cache_ref[l, ...], dtype=jnp.float32) - - for i in range(max_num_loras): - mask_a_ref[...] = jnp.zeros_like(mask_a_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - valid |= idx_ref[j + bT * t] == i - - @pl.when(idx_ref[j + bT * t] == i) - def _(): - mask_a_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) - - @pl.when(valid) - def _(): - cache_ref[l, ...] += jax.lax.dot_general( - inp_ref[...], - lora_a_ref[i, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_a_ref[...] - - cache_valid = d2 == (ds - 1) - @pl.when(cache_valid) - def _(): - @pl.when(l == 0) - def _(): - acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - - for i in range(max_num_loras): - mask_b_ref[...] = jnp.zeros_like(mask_b_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - valid |= idx_ref[j + bT * t] == i - - @pl.when(idx_ref[j + bT * t] == i) - def _(): - mask_b_ref[j, :] = jnp.ones((bD, ), dtype=jnp.float32) - - @pl.when(valid) - def _(): - acc_ref[...] += jax.lax.dot_general( - cache_ref[l, ...], - lora_b_ref[i, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_b_ref[...] - - @pl.when(l == ls - 1) - def _(): - out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) - -@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) -def _fused_bgmv_shrink_expand( - idxs: jax.Array, # (T, ) int32 - inputs: jax.Array, # (T, D) model dtype - loras_a: jax.Array, # (N, L, D) model dtype - loras_b: jax.Array, # (N, D, L) model dtype - *, - TOKEN_BLOCK: int, - LORA_BLOCK: int, - DIM_BLOCK: int -) -> jax.Array: # (T, D) model dtype - T, D = inputs.shape - N, L, _ = loras_a.shape - - return pl.pallas_call( - kernel=functools.partial(_fused_bgmv_shrink_expand_kernel, TOKEN_BLOCK, LORA_BLOCK, DIM_BLOCK, N), - out_shape=jax.ShapeDtypeStruct((T, D), dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // TOKEN_BLOCK, D // DIM_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), - in_specs=[ - pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (t, d2)), # Inputs - pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (0, l, d2)), # LoRA A - pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK), lambda t, d1, l, d2, block_idx: (0, d1, l)), # LoRA B - ], - out_specs=pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda t, d1, l, d2, block_idx: (t, d1)), - scratch_shapes=[ - pltpu.VMEM((L // LORA_BLOCK, TOKEN_BLOCK, LORA_BLOCK), jnp.float32), # Intermediates cache - pltpu.VMEM((TOKEN_BLOCK, DIM_BLOCK), jnp.float32), # Final accumulator - pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), # LoRA A mask - pltpu.VMEM((TOKEN_BLOCK, DIM_BLOCK), jnp.float32) # LoRA B mask - ] - ), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary", "arbitrary") - ), - name="fused_bgmv_shrink_expand" - )(idxs, inputs, loras_a, loras_b) - - -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") - def bgmv_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, L, _ = loras.shape return [((T, L), inputs.dtype)] + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) + + @impl(XLA_LIB, "bgmv", "XLA") def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) @@ -257,67 +153,3 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, _, L, _ = loras.shape return torch.empty((T, L), device=inputs.device) - -def fused_bgmv_shrink_expand_shape_function(idxs, inputs, loras_a, loras_b): - return [(inputs.shape, inputs.dtype)] - -XLA_LIB.define("fused_bgmv_shrink_expand(Tensor inputs, Tensor loras_a, Tensor loras_b, Tensor idxs) -> Tensor") - -@impl(XLA_LIB, "fused_bgmv_shrink_expand", "XLA") -def fused_bgmv_shrink_expand_xla(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - inputs = inputs.to(dtype=loras_a.dtype) - - if len(loras_a.shape) == 4: - loras_a = loras_a.squeeze(axis=1) - if len(loras_b.shape) == 4: - loras_b = loras_b.squeeze(axis=1) - - T, _ = inputs.shape - _, L, D = loras_a.shape - - jax_import_guard() - - TOKEN_BLOCK=16 - LORA_BLOCK=256 - DIM_BLOCK=1024 - - - kernel = make_kernel_from_pallas( - functools.partial( - _fused_bgmv_shrink_expand, - TOKEN_BLOCK=TOKEN_BLOCK, - LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK - ), - fused_bgmv_shrink_expand_shape_function - ) - - # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU - # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs - L1 = L - if LORA_BLOCK > L or L % LORA_BLOCK != 0: - L1 = (L // LORA_BLOCK + 1) * LORA_BLOCK - - D1 = D - if DIM_BLOCK > D or D % DIM_BLOCK != 0: - D1 = (D // DIM_BLOCK + 1) * DIM_BLOCK - - T1 = T - if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: - T1 = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - - if D1 != D or L1 != L: - loras_a = torch.nn.functional.pad(loras_a, (0, D1 - D, 0, L1 - L, 0, 0)) - loras_b = torch.nn.functional.pad(loras_b, (0, L1 - L, 0, D1 - D, 0, 0)) - if D1 != D or T1 != T: - inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) - if T1 != T: - idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) - - return kernel(idxs, inputs, loras_a, loras_b)[:T, :D] - - -@impl(XLA_LIB, "fused_bgmv_shrink_expand", "CompositeExplicitAutograd") -def fused_bgmv_shrink_expand_non_xla(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, - idxs: torch.IntTensor): - return torch.empty_like(inputs) \ No newline at end of file From 54344b7f56858e90ae5f91da8635050b60fa9b02 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 20:36:28 +0000 Subject: [PATCH 104/186] Added some autotuning for kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index fee8c30ba1b3..d5af38ef5a35 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -17,13 +17,13 @@ def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) t = pl.program_id(0) - + for i in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) valid = False for j in range(bT): valid |= idx_ref[j + bT * t] == i - + @pl.when(idx_ref[j + bT * t] == i) def _(): mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) @@ -99,14 +99,17 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): jax_import_guard() - TOKEN_BLOCK=16 + TOKEN_BLOCK = 16 if L > D: # Expand - LORA_BLOCK=1024 - DIM_BLOCK=256 + LORA_BLOCK = 1024 + DIM_BLOCK = 256 else: # Shrink - LORA_BLOCK=256 - DIM_BLOCK=1024 + LORA_BLOCK = 256 + DIM_BLOCK = 1024 + TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128) + LORA_BLOCK = max(LORA_BLOCK, pl.next_power_of_2(L)) + DIM_BLOCK = max(DIM_BLOCK, pl.next_power_of_2(D)) kernel = make_kernel_from_pallas( functools.partial( From c5a42e20378db77c0d673229d516f525e751b8bc Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 22:49:33 +0000 Subject: [PATCH 105/186] Renamed padding variables Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index d5af38ef5a35..52c5ebdfa756 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -123,24 +123,24 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs - L1 = L + pad_L = 0 if LORA_BLOCK > L or L % LORA_BLOCK != 0: - L1 = (L // LORA_BLOCK + 1) * LORA_BLOCK + pad_L = (L // LORA_BLOCK + 1) * LORA_BLOCK - L - D1 = D + pad_D = 0 if DIM_BLOCK > D or D % DIM_BLOCK != 0: - D1 = (D // DIM_BLOCK + 1) * DIM_BLOCK + pad_D = (D // DIM_BLOCK + 1) * DIM_BLOCK - D - T1 = T + pad_T = 0 if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: - T1 = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - - if D1 != D or L1 != L: - loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) - if D1 != D or T1 != T: - inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) - if T1 != T: - idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) + pad_T = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - T + + if pad_D != 0 or pad_L != 0: + loras = torch.nn.functional.pad(loras, (0, pad_D, 0, pad_L, 0, 0)) + if pad_D != 0 or pad_T != 0: + inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T)) + if pad_T != T: + idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) return kernel(idxs, inputs, loras)[:T, :L] From e66067c73ab1fd951475e43efdba7d72c81664a4 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 18 Mar 2025 23:36:28 +0000 Subject: [PATCH 106/186] Used a static ones vector, gives a 5%ish perf boost Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 52c5ebdfa756..8ad0a0822d33 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -17,6 +17,8 @@ def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) t = pl.program_id(0) + + ones = jnp.ones((bL, ), dtype=jnp.float32) for i in range(max_num_loras): mask_ref[...] 
= jnp.zeros_like(mask_ref[...], dtype=jnp.float32) @@ -26,7 +28,7 @@ def _(): @pl.when(idx_ref[j + bT * t] == i) def _(): - mask_ref[j, :] = jnp.ones((bL, ), dtype=jnp.float32) + mask_ref.at[j, :].set(ones) @pl.when(valid) def _(): @@ -96,11 +98,12 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape _, L, D = loras.shape + is_expand = L > D jax_import_guard() TOKEN_BLOCK = 16 - if L > D: # Expand + if is_expand: # Expand LORA_BLOCK = 1024 DIM_BLOCK = 256 else: # Shrink From 7c79683e6b89841ee91e8423a135e702de54c545 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 19 Mar 2025 21:43:11 +0000 Subject: [PATCH 107/186] Restricted block sizes to prevent memory from blowing up Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 8ad0a0822d33..6bb7302c3db6 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -111,8 +111,8 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): DIM_BLOCK = 1024 TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128) - LORA_BLOCK = max(LORA_BLOCK, pl.next_power_of_2(L)) - DIM_BLOCK = max(DIM_BLOCK, pl.next_power_of_2(D)) + LORA_BLOCK = min(max(LORA_BLOCK, pl.next_power_of_2(L)), 4096) + DIM_BLOCK = min(max(DIM_BLOCK, pl.next_power_of_2(D)), 4096) kernel = make_kernel_from_pallas( functools.partial( From d7338f8085cb65d3519ba1b8d0f0964251ace31f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 19 Mar 2025 22:21:11 +0000 Subject: [PATCH 108/186] Removed larger lora/dim block sizes since they reduce perf outside of microbenchmarks Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 6bb7302c3db6..87029550c72a 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -102,7 +102,7 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): jax_import_guard() - TOKEN_BLOCK = 16 + TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) if is_expand: # Expand LORA_BLOCK = 1024 DIM_BLOCK = 256 @@ -110,10 +110,6 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): LORA_BLOCK = 256 DIM_BLOCK = 1024 - TOKEN_BLOCK = min(max(TOKEN_BLOCK, pl.next_power_of_2(T)), 128) - LORA_BLOCK = min(max(LORA_BLOCK, pl.next_power_of_2(L)), 4096) - DIM_BLOCK = min(max(DIM_BLOCK, pl.next_power_of_2(D)), 4096) - kernel = make_kernel_from_pallas( functools.partial( _bgmv, @@ -128,15 +124,15 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs pad_L = 0 if LORA_BLOCK > L or L % LORA_BLOCK != 0: - pad_L = (L // LORA_BLOCK + 1) * LORA_BLOCK - L + pad_L = next_multiple_of(L, LORA_BLOCK) - L pad_D = 0 if DIM_BLOCK > D or D % DIM_BLOCK != 0: - pad_D = (D // DIM_BLOCK + 1) * DIM_BLOCK - D + pad_D = next_multiple_of(D, DIM_BLOCK) - D pad_T = 0 if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: - pad_T = (T // TOKEN_BLOCK + 1) * TOKEN_BLOCK - T + pad_T = next_multiple_of(T, TOKEN_BLOCK) - T if pad_D != 0 or pad_L != 0: loras = torch.nn.functional.pad(loras, (0, pad_D, 0, pad_L, 0, 0)) @@ -159,3 +155,12 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, _, L, _ = loras.shape return torch.empty((T, L), device=inputs.device) + + +def next_multiple_of(n: int, mult: int) -> int: + if n % mult == 0: + return n + return (n // mult + 1) * mult + +def get_bounded_value(_min: int, val: int, _max: int) -> int: + return min(max(_min, val), _max) \ No newline at end of file From 2bb886899f540463a43c47ec81ae77e3d1dd105c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 19 Mar 2025 22:53:29 +0000 Subject: [PATCH 109/186] Allowed smaller LoRA blocks if necessary Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 87029550c72a..ab6ffdf2b740 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -104,11 +104,11 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) if is_expand: # Expand - LORA_BLOCK = 1024 + LORA_BLOCK = min(1024, next_multiple_of(L, 256)) DIM_BLOCK = 256 else: # Shrink LORA_BLOCK = 256 - DIM_BLOCK = 1024 + DIM_BLOCK = min(1024, next_multiple_of(D, 256)) kernel = make_kernel_from_pallas( functools.partial( From 27ad7938a0556e8df569953fa9195fdc588b4ddd Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 20 Mar 2025 16:09:44 +0000 Subject: [PATCH 110/186] Replaced torch.cat operations with F.pad Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index f1b642e5852c..81c00f22dd5b 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.nn.functional as F # Required to register the custom ops import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import @@ -12,22 +13,21 @@ def bgmv_expand(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - n_tokens = outputs.size(0) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 - outputs = torch.cat( - (outputs, - torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), - device=outputs.device)), - dim=1) + if output_tensor.shape[1] > outputs.shape[1]: + outputs = F.pad( + outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 0, 0) + ) if add_inputs: - return output_tensor + outputs[:limit, :] + return output_tensor + outputs[:limit, :output_tensor.shape[1]] else: - return outputs[:limit, :] + return outputs[:limit, :output_tensor.shape[1]] def bgmv_shrink(inputs: 
torch.Tensor, @@ -48,16 +48,11 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_size: int, add_inputs: bool = True): outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - n_tokens = outputs.size(0) - outputs = torch.cat(( - torch.zeros((n_tokens, slice_offset), device=outputs.device), + outputs = F.pad( outputs, - torch.zeros( - (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), - device=outputs.device), - ), - dim=1) + (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0) + ) if add_inputs: return output_tensor + outputs From a82f3fe038c77d43d1ec7713fccf1d7fc9412ec0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 20 Mar 2025 17:54:00 +0000 Subject: [PATCH 111/186] Added fused lora transpose [experimental] Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 5 +- vllm/lora/layers.py | 7 +- vllm/lora/ops/xla_ops/lora_ops.py | 23 +-- vllm/lora/ops/xla_ops/pallas.py | 208 +++++++++++++++++++++---- vllm/lora/punica_wrapper/punica_tpu.py | 8 +- 5 files changed, 207 insertions(+), 44 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 5be0277b82c6..63b1c3d87ab3 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -74,7 +74,10 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): T, D, L, N, seed, dtype) # Run bgmv - output = torch.ops.xla.bgmv(inputs, loras, idxs) + if op_type == "shrink": + output = torch.ops.xla.bgmv(inputs, loras, idxs) + else: + output = torch.ops.xla.bgmv_expand(inputs, loras.transpose(2, 3), idxs) # Make sure we have no NaNs assert not torch.any(torch.isnan(output)) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index bb2dd63c8b62..94c57dc9fe05 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1049,6 +1049,9 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) + + self.lora_b_stacked = torch.transpose(self.lora_b_stacked, 2, 3) + self.embeddings_tensors = torch.full( (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), fill_value=float("-inf"), @@ -1081,8 +1084,8 @@ def set_lora( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) + 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( + lora_b, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 81c00f22dd5b..f9b93de908f5 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -11,18 +11,23 @@ def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + add_inputs: bool = True, + fused_transpose: bool = False): + + if fused_transpose: + outputs = torch.ops.xla.bgmv_pre_transpose(inputs, lora_b_weights, + lora_indices_tensor) + else: + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, + lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad( - outputs, - (0, output_tensor.shape[1] - outputs.shape[1], 0, 0) - ) + outputs = F.pad(outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 
0, 0)) if add_inputs: return output_tensor + outputs[:limit, :output_tensor.shape[1]] @@ -49,10 +54,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, add_inputs: bool = True): outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - outputs = F.pad( - outputs, - (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0) - ) + outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - + (slice_offset + slice_size), 0, 0)) if add_inputs: return output_tensor + outputs diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index ab6ffdf2b740..aa28d1d11a9e 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import math +from typing import List import jax import jax.numpy as jnp @@ -10,14 +12,16 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, - acc_ref, mask_ref): + +def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, + lora_ref, out_ref, acc_ref, mask_ref): + @pl.when(pl.program_id(2) == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) t = pl.program_id(0) - + ones = jnp.ones((bL, ), dtype=jnp.float32) for i in range(max_num_loras): @@ -42,16 +46,16 @@ def _(): out_ref[...] = acc_ref[...].astype(out_ref.dtype) -@functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) +@functools.partial(jax.jit, + static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) def _bgmv( - idxs: jax.Array, # (T, ) int32 - inputs: jax.Array, # (T, D) model dtype - loras: jax.Array, # (N, L, D) model dtype - *, - TOKEN_BLOCK: int, - LORA_BLOCK: int, - DIM_BLOCK: int -) -> jax.Array: # (T, L) model dtype + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array, # (N, L, D) model dtype + *, + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape @@ -60,8 +64,7 @@ def _bgmv( out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, - grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, - D // DIM_BLOCK), + grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), in_specs=[ pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), lambda i, j, k, block_idx: (i, k)), @@ -103,22 +106,18 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): jax_import_guard() TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) - if is_expand: # Expand + if is_expand: # Expand LORA_BLOCK = min(1024, next_multiple_of(L, 256)) DIM_BLOCK = 256 - else: # Shrink + else: # Shrink LORA_BLOCK = 256 DIM_BLOCK = min(1024, next_multiple_of(D, 256)) kernel = make_kernel_from_pallas( - functools.partial( - _bgmv, - TOKEN_BLOCK=TOKEN_BLOCK, - LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK - ), - bgmv_shape_function - ) + functools.partial(_bgmv, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK), bgmv_shape_function) # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs @@ -157,10 +156,163 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, return torch.empty((T, L), device=inputs.device) +# This kernel is similar to the one above but it assumes that the LoRA adapters +# have been pre-transposed. This lets us skip the data copies involved in +# transposing. +# We only need this for the expand op since the LoRA dimensions in the shrink op +# are small enough that the TPU can gather them without a data copy. +def _bgmv_pre_transpose_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, + inp_ref, lora_ref, out_ref, acc_ref, mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + t = pl.program_id(0) + + ones = jnp.ones((bL, ), dtype=jnp.float32) + + for i in range(max_num_loras): + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + valid = False + for j in range(bT): + valid |= idx_ref[j + bT * t] == i + + @pl.when(idx_ref[j + bT * t] == i) + def _(): + mask_ref.at[j, :].set(ones) + + @pl.when(valid) + def _(): + acc_ref[...] += jax.lax.dot( + inp_ref[...], + lora_ref[i, ...], + preferred_element_type=jnp.float32) * mask_ref[...] + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@functools.partial(jax.jit, + static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) +def _bgmv_pre_transpose( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array, # (N, L, D) model dtype + *, + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int) -> jax.Array: # (T, L) model dtype + T, D = inputs.shape + N, _, L = loras.shape + + return pl.pallas_call( + kernel=functools.partial(_bgmv_pre_transpose_kernel, TOKEN_BLOCK, + LORA_BLOCK, N), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), + in_specs=[ + pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK), + lambda i, j, k, block_idx: (0, k, j)), + ], + out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv_pre_transpose")(idxs, inputs, loras) + + +def bgmv_pre_transpose_shape_function(idxs, inputs, loras): + T, _ = inputs.shape + _, _, L = loras.shape + + return [((T, L), inputs.dtype)] + + +XLA_LIB.define( + "bgmv_pre_transpose(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) + + +@impl(XLA_LIB, "bgmv_pre_transpose", "XLA") +def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + T, _ = inputs.shape + _, D, L = loras.shape + + jax_import_guard() + + TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) + LORA_BLOCK = min(1024, next_multiple_of(L, 256)) + DIM_BLOCK = 256 + + kernel = make_kernel_from_pallas( + functools.partial(_bgmv_pre_transpose, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK), + bgmv_pre_transpose_shape_function) + + # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU + # register. This has to happen in pytorch, doing it in Jax will lead to NaNs + pad_L = 0 + if LORA_BLOCK > L or L % LORA_BLOCK != 0: + pad_L = next_multiple_of(L, LORA_BLOCK) - L + + pad_D = 0 + if DIM_BLOCK > D or D % DIM_BLOCK != 0: + pad_D = next_multiple_of(D, DIM_BLOCK) - D + + pad_T = 0 + if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: + pad_T = next_multiple_of(T, TOKEN_BLOCK) - T + + if pad_D != 0 or pad_L != 0: + loras = torch.nn.functional.pad(loras, (0, pad_L, 0, pad_D, 0, 0)) + if pad_D != 0 or pad_T != 0: + inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T)) + if pad_T != T: + idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) + + return kernel(idxs, inputs, loras)[:T, :L] + + +@impl(XLA_LIB, "bgmv_pre_transpose", "CompositeExplicitAutograd") +def bgmv_pre_transpose_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, _, L = loras.shape + + return torch.empty((T, L), device=inputs.device) + + +def largest_divisor(n: int, divs: List[int]) -> int: + for div in sorted(divs, reverse=True): + if n % div == 0: + return div + return max(divs) + + def next_multiple_of(n: int, mult: int) -> int: - if n % mult == 0: - return n - return (n // mult + 1) * mult + return math.ceil(n / mult) * mult + def get_bounded_value(_min: int, val: int, _max: int) -> int: - return min(max(_min, val), _max) \ No newline at end of file + return min(max(_min, val), _max) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 38b4b393be8b..9ebbbf2ec6ff 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -277,11 +277,12 @@ def add_lora_logits(self, y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) + + rank = lora_b_stacked.size(-1) if buffer is None: # We set the buffer to be float32 by default, consistent with the # triton op - buffer = torch.zeros((x.size(0), r), + buffer = torch.zeros((x.size(0), rank), dtype=torch.float32, device=x.device) @@ -291,7 +292,8 @@ def add_lora_logits(self, lora_b_stacked, y, self.sampler_indices, - add_inputs=True) + add_inputs=True, + fused_transpose=True) return y.view_as(y_org) def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: From de6746ac3efb478890bf73735bd0c1beaf589646 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 20 Mar 2025 19:16:06 +0000 Subject: [PATCH 112/186] Separated bgmv_shrink and bgmv_expand kernels to avoid unneccessary data copies during transpose Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 2 +- vllm/lora/layers.py | 7 +--- vllm/lora/ops/xla_ops/lora_ops.py | 15 +++---- vllm/lora/ops/xla_ops/pallas.py | 57 +++++++++++--------------- vllm/lora/punica_wrapper/punica_tpu.py | 3 +- 5 files changed, 32 insertions(+), 52 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 63b1c3d87ab3..1098f8e3d47e 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -75,7 +75,7 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): # Run bgmv if op_type == "shrink": - output = torch.ops.xla.bgmv(inputs, loras, idxs) + output = torch.ops.xla.bgmv_shrink(inputs, loras, idxs) else: output = torch.ops.xla.bgmv_expand(inputs, loras.transpose(2, 3), idxs) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 
94c57dc9fe05..bb2dd63c8b62 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1049,9 +1049,6 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) - - self.lora_b_stacked = torch.transpose(self.lora_b_stacked, 2, 3) - self.embeddings_tensors = torch.full( (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), fill_value=float("-inf"), @@ -1084,8 +1081,8 @@ def set_lora( 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) self.lora_b_stacked[index, - 0, :lora_b.shape[0], :lora_b.shape[1]].copy_( - lora_b, non_blocking=True) + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) if embeddings_tensor is not None: self.embeddings_tensors[ index, diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index f9b93de908f5..750def2d5d32 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -11,15 +11,10 @@ def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, - fused_transpose: bool = False): + add_inputs: bool = True): - if fused_transpose: - outputs = torch.ops.xla.bgmv_pre_transpose(inputs, lora_b_weights, - lora_indices_tensor) - else: - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, - lora_indices_tensor) + outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), + lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -41,7 +36,7 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, + return scaling * torch.ops.xla.bgmv_shrink(inputs, lora_b_weights, lora_indices_tensor) @@ -52,7 +47,7 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), lora_indices_tensor) outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0)) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index aa28d1d11a9e..42098b48ae97 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -12,8 +12,11 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) +XLA_LIB.define("bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") +XLA_LIB.define( + "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") -def _bgmv_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, +def _bgmv_shrink_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) @@ -48,7 +51,7 @@ def _(): @functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) -def _bgmv( +def _bgmv_shrink( idxs: jax.Array, # (T, ) int32 inputs: jax.Array, # (T, D) model dtype loras: jax.Array, # (N, L, D) model dtype @@ -60,7 +63,7 @@ def _bgmv( N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKEN_BLOCK, LORA_BLOCK, N), + kernel=functools.partial(_bgmv_shrink_kernel, TOKEN_BLOCK, LORA_BLOCK, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( 
num_scalar_prefetch=1, @@ -81,19 +84,15 @@ def _bgmv( dimension_semantics=("parallel", "parallel", "arbitrary")), name="bgmv")(idxs, inputs, loras) - -def bgmv_shape_function(idxs, inputs, loras): +def bgmv_shrink_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, L, _ = loras.shape return [((T, L), inputs.dtype)] -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) - - -@impl(XLA_LIB, "bgmv", "XLA") -def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): +@impl(XLA_LIB, "bgmv_shrink", "XLA") +def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) if len(loras.shape) == 4: @@ -114,10 +113,10 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): DIM_BLOCK = min(1024, next_multiple_of(D, 256)) kernel = make_kernel_from_pallas( - functools.partial(_bgmv, + functools.partial(_bgmv_shrink, TOKEN_BLOCK=TOKEN_BLOCK, LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK), bgmv_shape_function) + DIM_BLOCK=DIM_BLOCK), bgmv_shrink_shape_function) # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. This has to happen in pytorch, doing it in Jax will lead to NaNs @@ -142,9 +141,8 @@ def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): return kernel(idxs, inputs, loras)[:T, :L] - -@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, +@impl(XLA_LIB, "bgmv_shrink", "CompositeExplicitAutograd") +def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape @@ -161,7 +159,7 @@ def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, # transposing. # We only need this for the expand op since the LoRA dimensions in the shrink op # are small enough that the TPU can gather them without a data copy. 
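# As a rough reference (an illustrative sketch, not the kernel's actual API):
# with the adapters pre-transposed to shape (N, D, L), the expand op reduces to
# a per-token gather plus matmul, e.g.
#     selected = loras[idxs]                                # (T, D, L)
#     out = torch.einsum("td,tdl->tl", inputs, selected)    # (T, L)
# so no (L, D) -> (D, L) transpose copy is needed at dispatch time.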
-def _bgmv_pre_transpose_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, +def _bgmv_expand_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) @@ -196,7 +194,7 @@ def _(): @functools.partial(jax.jit, static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) -def _bgmv_pre_transpose( +def _bgmv_expand( idxs: jax.Array, # (T, ) int32 inputs: jax.Array, # (T, D) model dtype loras: jax.Array, # (N, L, D) model dtype @@ -208,7 +206,7 @@ def _bgmv_pre_transpose( N, _, L = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_pre_transpose_kernel, TOKEN_BLOCK, + kernel=functools.partial(_bgmv_expand_kernel, TOKEN_BLOCK, LORA_BLOCK, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( @@ -231,19 +229,14 @@ def _bgmv_pre_transpose( name="bgmv_pre_transpose")(idxs, inputs, loras) -def bgmv_pre_transpose_shape_function(idxs, inputs, loras): +def bgmv_expand_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, _, L = loras.shape return [((T, L), inputs.dtype)] - -XLA_LIB.define( - "bgmv_pre_transpose(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) - - -@impl(XLA_LIB, "bgmv_pre_transpose", "XLA") -def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor, +@impl(XLA_LIB, "bgmv_expand", "XLA") +def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) @@ -260,11 +253,11 @@ def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor, DIM_BLOCK = 256 kernel = make_kernel_from_pallas( - functools.partial(_bgmv_pre_transpose, + functools.partial(_bgmv_expand, TOKEN_BLOCK=TOKEN_BLOCK, LORA_BLOCK=LORA_BLOCK, DIM_BLOCK=DIM_BLOCK), - bgmv_pre_transpose_shape_function) + bgmv_expand_shape_function) # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs @@ -289,9 +282,8 @@ def bgmv_pre_transpose_xla(inputs: torch.Tensor, loras: torch.Tensor, return kernel(idxs, inputs, loras)[:T, :L] - -@impl(XLA_LIB, "bgmv_pre_transpose", "CompositeExplicitAutograd") -def bgmv_pre_transpose_non_xla(inputs: torch.Tensor, loras: torch.Tensor, +@impl(XLA_LIB, "bgmv_expand", "CompositeExplicitAutograd") +def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): T, _ = inputs.shape @@ -302,17 +294,14 @@ def bgmv_pre_transpose_non_xla(inputs: torch.Tensor, loras: torch.Tensor, return torch.empty((T, L), device=inputs.device) - def largest_divisor(n: int, divs: List[int]) -> int: for div in sorted(divs, reverse=True): if n % div == 0: return div return max(divs) - def next_multiple_of(n: int, mult: int) -> int: return math.ceil(n / mult) * mult - def get_bounded_value(_min: int, val: int, _max: int) -> int: return min(max(_min, val), _max) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 9ebbbf2ec6ff..c821079dc829 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -292,8 +292,7 @@ def add_lora_logits(self, lora_b_stacked, y, self.sampler_indices, - add_inputs=True, - fused_transpose=True) + add_inputs=True) return y.view_as(y_org) def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: From 19b908988d1671884e3c7ab59ba6a60a6d8a82af Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 20 Mar 2025 21:39:06 +0000 Subject: [PATCH 113/186] Removed redundant branch Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 42098b48ae97..7a70595edf23 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -105,12 +105,8 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTe jax_import_guard() TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) - if is_expand: # Expand - LORA_BLOCK = min(1024, next_multiple_of(L, 256)) - DIM_BLOCK = 256 - else: # Shrink - LORA_BLOCK = 256 - DIM_BLOCK = min(1024, next_multiple_of(D, 256)) + LORA_BLOCK = 256 + DIM_BLOCK = min(1024, next_multiple_of(D, 256)) kernel = make_kernel_from_pallas( functools.partial(_bgmv_shrink, From e07d6fb73d4022978529a0b500da34f8dd8dabf1 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Mar 2025 19:55:19 +0000 Subject: [PATCH 114/186] Moved punica related `mark_dynamic` to the TPUModelRunner to allow the `enforce_eager` to work Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 2 -- vllm/v1/worker/tpu_model_runner.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 38b4b393be8b..e5cd6d6b1fdf 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -31,8 +31,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) - torch._dynamo.mark_dynamic(self._embeddings_indices, 1) - torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) @property def embeddings_indices(self) -> torch.Tensor: diff --git 
a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 0e499107ff5c..523388715dd3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -775,6 +775,10 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) + torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, 0) with set_forward_context(attn_metadata, self.vllm_config, 0): self.model(input_ids=input_ids, From 5b4ba1ba4e4a978230bffbd03d0ead4fa51ce398 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Mar 2025 19:57:02 +0000 Subject: [PATCH 115/186] Moved `maybe_dummy_run_with_lora` to the `_dummy_run` method Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 523388715dd3..bd35efbb1474 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -780,11 +780,13 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, 0) - with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(input_ids=input_ids, - positions=position_ids, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds) + with self.maybe_dummy_run_with_lora( + self.lora_config, np.array([num_tokens], dtype=np.int32)): + with set_forward_context(attn_metadata, self.vllm_config, 0): + self.model(input_ids=input_ids, + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) def capture_model(self) -> None: """Compile the model.""" @@ -795,9 +797,7 @@ def capture_model(self) -> None: num_tokens = 16 while True: logger.info(" -- num_tokens: %d", num_tokens) - with self.maybe_dummy_run_with_lora( - self.lora_config, np.array([num_tokens], dtype=np.int32)): - self._dummy_run(self.kv_caches, num_tokens) + self._dummy_run(self.kv_caches, num_tokens) xm.mark_step() if num_tokens >= self.max_num_tokens: break From b64dc319e1dc9b34421273c4280e4305915c9b7a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Mar 2025 20:42:02 +0000 Subject: [PATCH 116/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 7 +++--- vllm/lora/ops/xla_ops/pallas.py | 36 +++++++++++++++++++----------- vllm/v1/worker/tpu_model_runner.py | 19 +++++++++------- 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index 750def2d5d32..b44ab06af1a9 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -14,7 +14,7 @@ def bgmv_expand(inputs: torch.Tensor, add_inputs: bool = True): outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), - lora_indices_tensor) + lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -37,7 +37,7 @@ def bgmv_shrink(inputs: torch.Tensor, scaling: float = 1.0): return scaling * torch.ops.xla.bgmv_shrink(inputs, lora_b_weights, - lora_indices_tensor) + lora_indices_tensor) def bgmv_expand_slice(inputs: torch.Tensor, @@ -47,7 
+47,8 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): - outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), lora_indices_tensor) + outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), + lora_indices_tensor) outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0)) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 7a70595edf23..c2cf330da103 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -12,12 +12,14 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) -XLA_LIB.define("bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") +XLA_LIB.define( + "bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") XLA_LIB.define( "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") + def _bgmv_shrink_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, - lora_ref, out_ref, acc_ref, mask_ref): + lora_ref, out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) def _(): @@ -63,7 +65,8 @@ def _bgmv_shrink( N, L, _ = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_shrink_kernel, TOKEN_BLOCK, LORA_BLOCK, N), + kernel=functools.partial(_bgmv_shrink_kernel, TOKEN_BLOCK, LORA_BLOCK, + N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, @@ -84,6 +87,7 @@ def _bgmv_shrink( dimension_semantics=("parallel", "parallel", "arbitrary")), name="bgmv")(idxs, inputs, loras) + def bgmv_shrink_shape_function(idxs, inputs, loras): T, _ = inputs.shape _, L, _ = loras.shape @@ -92,7 +96,8 @@ def bgmv_shrink_shape_function(idxs, inputs, loras): @impl(XLA_LIB, "bgmv_shrink", "XLA") -def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): +def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) if len(loras.shape) == 4: @@ -137,9 +142,10 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTe return kernel(idxs, inputs, loras)[:T, :L] + @impl(XLA_LIB, "bgmv_shrink", "CompositeExplicitAutograd") def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): + idxs: torch.IntTensor): T, _ = inputs.shape if len(loras.shape) == 4: @@ -155,8 +161,8 @@ def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, # transposing. # We only need this for the expand op since the LoRA dimensions in the shrink op # are small enough that the TPU can gather them without a data copy. 
-def _bgmv_expand_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, - inp_ref, lora_ref, out_ref, acc_ref, mask_ref): +def _bgmv_expand_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, + lora_ref, out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) def _(): @@ -202,8 +208,8 @@ def _bgmv_expand( N, _, L = loras.shape return pl.pallas_call( - kernel=functools.partial(_bgmv_expand_kernel, TOKEN_BLOCK, - LORA_BLOCK, N), + kernel=functools.partial(_bgmv_expand_kernel, TOKEN_BLOCK, LORA_BLOCK, + N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, @@ -231,9 +237,10 @@ def bgmv_expand_shape_function(idxs, inputs, loras): return [((T, L), inputs.dtype)] + @impl(XLA_LIB, "bgmv_expand", "XLA") def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): + idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) if len(loras.shape) == 4: @@ -252,8 +259,7 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, functools.partial(_bgmv_expand, TOKEN_BLOCK=TOKEN_BLOCK, LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK), - bgmv_expand_shape_function) + DIM_BLOCK=DIM_BLOCK), bgmv_expand_shape_function) # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. This has to happen in pytorch, doing it in Jax will lead to NaNs @@ -278,9 +284,10 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, return kernel(idxs, inputs, loras)[:T, :L] + @impl(XLA_LIB, "bgmv_expand", "CompositeExplicitAutograd") def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): + idxs: torch.IntTensor): T, _ = inputs.shape if len(loras.shape) == 4: @@ -290,14 +297,17 @@ def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, return torch.empty((T, L), device=inputs.device) + def largest_divisor(n: int, divs: List[int]) -> int: for div in sorted(divs, reverse=True): if n % div == 0: return div return max(divs) + def next_multiple_of(n: int, mult: int) -> int: return math.ceil(n / mult) * mult + def get_bounded_value(_min: int, val: int, _max: int) -> int: return min(max(_min, val), _max) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bd35efbb1474..aa7fe64a4a38 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -714,7 +714,7 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) - if self.lora_config: + if self.lora_config is not None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) @@ -775,18 +775,14 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - - punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper - torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) - torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, 0) with self.maybe_dummy_run_with_lora( self.lora_config, np.array([num_tokens], dtype=np.int32)): with set_forward_context(attn_metadata, self.vllm_config, 0): self.model(input_ids=input_ids, - positions=position_ids, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds) + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) def capture_model(self) -> None: 
"""Compile the model.""" @@ -811,6 +807,13 @@ def capture_model(self) -> None: num_tokens = 16 hsize = self.model_config.get_hidden_size() device = self.device + + if self.lora_config is not None: + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) + torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, + 0) + # Compile sampling step for different model+sampler outputs in bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. while True: From 49a81026a024bac344759336001706fae3bcf2b1 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Mar 2025 20:47:40 +0000 Subject: [PATCH 117/186] Minor fixes + lint Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index bd35efbb1474..aa7fe64a4a38 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -714,7 +714,7 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) - if self.lora_config: + if self.lora_config is not None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) @@ -775,18 +775,14 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(input_ids, 0) torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - - punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper - torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) - torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, 0) with self.maybe_dummy_run_with_lora( self.lora_config, np.array([num_tokens], dtype=np.int32)): with set_forward_context(attn_metadata, self.vllm_config, 0): self.model(input_ids=input_ids, - positions=position_ids, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds) + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) def capture_model(self) -> None: """Compile the model.""" @@ -811,6 +807,13 @@ def capture_model(self) -> None: num_tokens = 16 hsize = self.model_config.get_hidden_size() device = self.device + + if self.lora_config is not None: + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) + torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, + 0) + # Compile sampling step for different model+sampler outputs in bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. 
while True: From c1be5f9a84d528bf7c0e2825236b2bfc5440c725 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 24 Mar 2025 20:55:05 +0000 Subject: [PATCH 118/186] Lint Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index aa7fe64a4a38..cd361754a185 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -777,12 +777,13 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) with self.maybe_dummy_run_with_lora( - self.lora_config, np.array([num_tokens], dtype=np.int32)): - with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(input_ids=input_ids, - positions=position_ids, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds) + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + attn_metadata, self.vllm_config, 0): + self.model(input_ids=input_ids, + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) def capture_model(self) -> None: """Compile the model.""" From bf44d651af871b78650fb4ea2b4bde1e71760148 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Mar 2025 19:49:25 +0000 Subject: [PATCH 119/186] Fixed mark_dynamic placement for eager/compiled modes Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 4 ++++ vllm/v1/worker/tpu_model_runner.py | 10 ++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index fd1e76bc4ff3..cd6c7d70fdaf 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -32,6 +32,10 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) + def mark_compiled(self): + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + @property def embeddings_indices(self) -> torch.Tensor: """ diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index aa7fe64a4a38..5d429f131695 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -718,6 +718,10 @@ def load_model(self) -> None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + if not self.enforce_eager: + punica_wrapper.mark_compiled() + model = model.eval() xm.mark_step() xm.wait_device_ops() @@ -808,12 +812,6 @@ def capture_model(self) -> None: hsize = self.model_config.get_hidden_size() device = self.device - if self.lora_config is not None: - punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper - torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) - torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, - 0) - # Compile sampling step for different model+sampler outputs in bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. 
while True: From 15ff074e710ed05a9d4ec8db1bb39957ecfbfcaa Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Mar 2025 19:49:25 +0000 Subject: [PATCH 120/186] Fixed mark_dynamic placement for eager/compiled modes Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 4 ++++ vllm/v1/worker/tpu_model_runner.py | 10 ++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index e5cd6d6b1fdf..6b2ccf9148ab 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -32,6 +32,10 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) + def mark_compiled(self): + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + @property def embeddings_indices(self) -> torch.Tensor: """ diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index cd361754a185..188efcbeb3db 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -718,6 +718,10 @@ def load_model(self) -> None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + if not self.enforce_eager: + punica_wrapper.mark_compiled() + model = model.eval() xm.mark_step() xm.wait_device_ops() @@ -809,12 +813,6 @@ def capture_model(self) -> None: hsize = self.model_config.get_hidden_size() device = self.device - if self.lora_config is not None: - punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper - torch._dynamo.mark_dynamic(punica_wrapper._embeddings_indices, 1) - torch._dynamo.mark_dynamic(punica_wrapper._sampler_indices_padded, - 0) - # Compile sampling step for different model+sampler outputs in bucketed # n_tokens x max_num_reqs. Graph is really small so this is fine. while True: From d9f89b6b9486e88df4d4892b392aac01943fcf48 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Mar 2025 19:53:17 +0000 Subject: [PATCH 121/186] Temporary fix to LogitsProcessorWithLoRA pipeline bubble issue Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index cd6c7d70fdaf..ab6c653805d0 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -39,7 +39,7 @@ def mark_compiled(self): @property def embeddings_indices(self) -> torch.Tensor: """ - This property provides access to the indices used for lora embeddings, + This property provides access to the indices used for lora embeddings, specifically for VocabParallelEmbeddingWithLoRA. """ return self._embeddings_indices[:] @@ -108,7 +108,6 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x = x.view(-1, x.shape[-1]) new_y = [] - # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] @@ -273,7 +272,8 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. 
""" - if self.no_lora: + # Temporary fix to pipeline bubble bug + if self.no_lora or lora_a_stacked.sum() == 0: return y y_org = y @@ -282,10 +282,8 @@ def add_lora_logits(self, rank = lora_b_stacked.size(-1) if buffer is None: - # We set the buffer to be float32 by default, consistent with the - # triton op buffer = torch.zeros((x.size(0), rank), - dtype=torch.float32, + dtype=y.dtype, device=x.device) buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, @@ -301,5 +299,4 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) - # TODO: .item() is extremely inefficient on TPU, so find a way around it self.no_lora = torch.all(token_lora_tensor == -1).item() From 81775d3ef781b3ec99408381b1f6865131b12327 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 25 Mar 2025 19:53:40 +0000 Subject: [PATCH 122/186] Sampler is now compiled with LoRA Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5d429f131695..433b792d5502 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -833,7 +833,10 @@ def capture_model(self) -> None: num_reqs_to_sample, device) logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs_to_sample) - self.model.sample_from_hidden(dummy_hidden, sampling_meta) + + with self.maybe_dummy_run_with_lora(self.lora_config, np.array([num_tokens], dtype=np.int32)): + self.model.sample_from_hidden(dummy_hidden, sampling_meta) + xm.mark_step() if num_reqs_to_sample >= self.max_num_reqs: break From 829028d3b8de60ea1d3dd0618d835fae8af39dc0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 26 Mar 2025 17:31:07 +0000 Subject: [PATCH 123/186] Removed early exits since they cause eager execution Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index ab6c653805d0..c62158063819 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -58,8 +58,6 @@ def shrink( w_t_all: torch.Tensor, scale: float, ): - if self.no_lora: - return y return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) def expand( @@ -69,8 +67,6 @@ def expand( w_t_all: torch.Tensor, add_inputs: bool, ): - if self.no_lora: - return y return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def expand_slice( @@ -83,8 +79,6 @@ def expand_slice( y_total_size: int, add_inputs: bool, ) -> torch.Tensor: - if self.no_lora: - return y return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) @@ -272,10 +266,6 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. 
""" - # Temporary fix to pipeline bubble bug - if self.no_lora or lora_a_stacked.sum() == 0: - return y - y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) @@ -299,4 +289,3 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) - self.no_lora = torch.all(token_lora_tensor == -1).item() From 5638e7da75905acebfa3dac3844516c8cd068dc5 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 01:00:53 +0000 Subject: [PATCH 124/186] Removed some recompilations when updating LoRA metadata Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 54 +++++++++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index c62158063819..13822dbb6770 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,10 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink +from vllm.lora.punica_wrapper.utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext from .punica_base import PunicaWrapperBase @@ -284,6 +290,52 @@ def add_lora_logits(self, self.sampler_indices, add_inputs=True) return y.view_as(y_org) + + # This performs the same tensor ops as the base method, except it does them + # on the CPU then transfers the results to the TPU + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + # Pad the prompt mapping to avoid running into recompiles on the TPU + pad_len = len(mapping.index_mapping) - len(mapping.prompt_mapping) + padding = [-1] * pad_len + mapping.prompt_mapping = tuple(list(mapping.prompt_mapping) + padding) + + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + "cpu", + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices.to(self.device)) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices.to(self.device)) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded.to(self.device)) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices.to(self.device)) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor.to(self.device)) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 From bae61a2c726bb53cfadb629d8b5f0b30d876889a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 17:32:21 +0000 Subject: [PATCH 125/186] Aligned lora codepath with recompilation fixes Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 33 ++++++++++++++++------ vllm/v1/worker/tpu_model_runner.py | 39 ++++++++++++++++++-------- 2 files changed, 52 insertions(+), 20 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 13822dbb6770..65565532094f 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import math from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -290,8 +291,8 @@ def add_lora_logits(self, self.sampler_indices, add_inputs=True) return y.view_as(y_org) - - # This performs the same tensor ops as the base method, except it does them + + # This performs the same tensor ops as the base method, except it does them # on the CPU then transfers the results to the TPU def _update_base_metadata( self, @@ -303,10 +304,11 @@ def _update_base_metadata( long_lora_context: Optional["LongContextLoRAContext"] = None, ): # Pad the prompt mapping to avoid running into recompiles on the TPU - pad_len = len(mapping.index_mapping) - len(mapping.prompt_mapping) - padding = [-1] * pad_len - mapping.prompt_mapping = tuple(list(mapping.prompt_mapping) + padding) - + # TODO: Should this happen inside mapping internally? If so how can we + # avoid having backend specific LoRAMapping classes? + mapping.prompt_mapping = self._pad_prompt_mapping( + mapping.prompt_mapping) + ( base_indices, sampler_indices, @@ -323,8 +325,10 @@ def _update_base_metadata( "cpu", long_lora_context, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices.to(self.device)) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices.to(self.device)) + self._token_lora_indices[:base_indices.shape[0]].copy_( + base_indices.to(self.device)) + self._sampler_indices[:sampler_indices.shape[0]].copy_( + sampler_indices.to(self.device)) self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( sampler_indices_padded.to(self.device)) self._embeddings_indices[:embeddings_indices. 
@@ -341,3 +345,16 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) + + def _pad_prompt_mapping( + self, prompt_mapping: Tuple[int, ...]) -> Tuple[int, ...]: + num_reqs = len(prompt_mapping) + + # From vllm/v1/worker/tppu_model_runner:52, but need to avoid a circular import + MIN_NUM_SEQS = 8 + + padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + pad_len = padded_num_reqs - num_reqs + + padding = [-1] * pad_len + return tuple(list(prompt_mapping) + padding) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index a49bd989c3ee..123694bbc527 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -777,12 +777,13 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) with self.maybe_dummy_run_with_lora( - self.lora_config, np.array([num_tokens], dtype=np.int32)): - with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(input_ids=input_ids, - positions=position_ids, - kv_caches=kv_caches, - inputs_embeds=inputs_embeds) + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + attn_metadata, self.vllm_config, 0): + self.model(input_ids=input_ids, + positions=position_ids, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds) def capture_model(self) -> None: """Compile the model.""" @@ -810,7 +811,8 @@ def capture_model(self) -> None: dummy_hidden = torch.randn((num_tokens, hsize), device=device, dtype=torch.bfloat16) - while True: + while num_reqs_to_sample <= self.max_num_reqs and \ + num_reqs_to_sample <= num_tokens: indices = torch.zeros( num_reqs_to_sample, dtype=torch.int32, @@ -822,13 +824,14 @@ def capture_model(self) -> None: logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs_to_sample) - with self.maybe_dummy_run_with_lora(self.lora_config, np.array([num_tokens], dtype=np.int32)): - out = self.model.sample_from_hidden(dummy_hidden, - sampling_meta) + with self.maybe_dummy_run_with_lora( + self.lora_config, + _create_dummy_scheduled_tokens(num_tokens, + num_reqs_to_sample)): + out = self.model.sample_from_hidden( + dummy_hidden, sampling_meta) out = out.cpu() - if num_reqs_to_sample >= self.max_num_reqs: - break num_reqs_to_sample *= 2 xm.wait_device_ops() end = time.perf_counter() @@ -1008,3 +1011,15 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] + + +def _create_dummy_scheduled_tokens(total_tokens: int, + num_prompts: int) -> np.ndarray: + assert num_prompts <= total_tokens, "Expected num_prompts < total_tokens" + base_tokens = total_tokens // num_prompts + leftover_tokens = total_tokens % num_prompts + + tokens = np.full((num_prompts, ), base_tokens, dtype=np.int32) + tokens[-1] += leftover_tokens + + return tokens From dc8b94076bc4e76ff6edacbe73765f4182a49494 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 17:58:56 +0000 Subject: [PATCH 126/186] Disabled add_lora_logits temporarily Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 65565532094f..11296b68dc82 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ 
b/vllm/lora/punica_wrapper/punica_tpu.py @@ -273,6 +273,8 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ + return y + y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) From eb804a0a898c3386772ae1dd832b38e57f5e6748 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 20:52:56 +0000 Subject: [PATCH 127/186] Added the LoRA Laning optimisation + tests + explanation Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 29 ++- vllm/lora/ops/xla_ops/lora_ops.py | 12 +- vllm/lora/ops/xla_ops/pallas.py | 234 +++++++++++++++++++++---- vllm/lora/punica_wrapper/punica_tpu.py | 15 +- 4 files changed, 243 insertions(+), 47 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 1098f8e3d47e..d253351158a8 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -30,12 +30,12 @@ def generate_test_data(T, D, L, N, seed, dtype=torch.float32): D: Input dim L: LoRA Dim N: N LoRAs - + Outputs: inputs: torch.Tensor - shape (T, D) loras: torch.Tensor - shape (N, 1, L, D) idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) - + ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T """ torch.manual_seed(seed) @@ -84,3 +84,28 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): # Compare with reference output assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", [0]) +def test_lora_laning_correctness(T, D, L, N, dtype, seed): + inputs, loras_a, idxs, _ = generate_test_data(T, D, L, N, seed, dtype) + _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype) + + r1 = ref_bgmv(inputs, loras_a, idxs) + r2 = ref_bgmv(r1, loras_b, idxs) + + o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs) + o2 = torch.ops.xla.bgmv_expand( + o1, + loras_b.transpose(2, 3), + idxs, + True + ) + + # Compare with reference output + assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index b44ab06af1a9..f5088c37745f 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -11,10 +11,12 @@ def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): + add_inputs: bool = True, + *, + enable_laning: bool = False): outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), - lora_indices_tensor) + lora_indices_tensor, enable_laning) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -46,9 +48,11 @@ def bgmv_expand_slice(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, slice_offset: int, slice_size: int, - add_inputs: bool = True): + add_inputs: bool = True, + *, + enable_laning: bool = False): outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), - lora_indices_tensor) + lora_indices_tensor, enable_laning) outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0)) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 
c2cf330da103..17bafef43ccb 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -14,12 +14,138 @@ XLA_LIB.define( "bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") -XLA_LIB.define( - "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") - -def _bgmv_shrink_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, - lora_ref, out_ref, acc_ref, mask_ref): +# bgmv_expand needs a flag to enable LoRA laning since it expects its inputs to +# be the outputs of a LoRA laned bgmv_shrink. This is not always the case when +# we use bgmv_expand +XLA_LIB.define( + "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs, bool enable_laning) -> Tensor" +) +""" +LoRA Laning Optimization for TPU Matrix Multiplication + +When we run with the TPU we need to keep its MXU (matrix multiplication unit) +well fed to achieve maximum utilisation. +The MXU can perform an (8x128) by (128x128) matmul once every 8 cycles. + +LoRA computations typically take a series of T (1xD) vectors and matmul them +with a (DxL) matrix (shrinking) followed by another matmul with a (LxD) matrix +(expanding). Grouping the vectors we get a (TxD) matrix, so our computations +become matmul((TxD), (DxL)) and matmul((TxL), (LxD)). + +The number of tokens (T) and the hidden dimension (D) are usually greater than +8 and 128 respectively, however the LoRA rank (L) is usually a smaller value, +around 8-64, which means we need to pad L to allow it to fit in a TPU register. + + +------------------+ + | Shrink Operation | + +------------------+ + + L + +------------------+ + D | 1111000000000000 | L + +------------------+ | 1111000000000000 | +------------------+ + | 1111111111111111 | | 1111000000000000 | | 1111000000000000 | + T | 2222222222222222 | x D | 1111000000000000 | = T | 1111000000000000 | + +------------------+ | 1111000000000000 | +------------------+ + | 1111000000000000 | + | 1111000000000000 | + | 1111000000000000 | + +------------------+ + +Here we have 4 tokens each needing a different LoRA adapter, and 1 LoRA adapter +loaded into the MXU. After the matmul we end up with the result of applying +LoRA 1 to all T tokens, but since only one token needs LoRA 1, we mask out +everything we don't need to get: + + D + +------------------+ + | 1111000000000000 | + | 0000000000000000 | + +------------------+ + +However, we need: + + L + +------------------+ + | 1111000000000000 | + | 2222000000000000 | + +------------------+ + +So we'll have to perform another matmul. +Overall this shrink wastes time and memory padding the LoRA adapters and running +extra matmuls. + +We can get both reduce the number of matmuls used and the amount of applied +padding by grouping the LoRA adapters into multiple "lanes". + + L + +------------------+ + D | 1111222200000000 | L + +------------------+ | 1111222200000000 | +------------------+ + | 1111111111111111 | | 1111222200000000 | | 1111222200000000 | + T | 2222222222222222 | x D | 1111222200000000 | = T | 1111222200000000 | + +------------------+ | 1111222200000000 | +------------------+ + | 1111222200000000 | + | 1111222200000000 | + | 1111222200000000 | + +------------------+ + + +Now we're able to compute the outputs of 4 different LoRA adapters in the same +8 cycles. 
However we don't need all these results so we'll again mask out +everything we don't need to get: + + L + +------------------+ + | 1111000000000000 | + | 0000222200000000 | + +------------------+ + +But now our outputs aren't aligned properly, so we would need to apply an extra +shuffle operation. + + +------------------+ + | Expand Operation | + +------------------+ + +When expanding we end up wasting space in both matrix registers. + + D + +------------------+ + L | 1111111111111111 | D + +------------------+ | 1111111111111111 | +------------------+ + | 1111000000000000 | | 1111111111111111 | | 1111111111111111 | + T | 2222000000000000 | x L | 1111111111111111 | = T | 1111111111111111 | + +------------------+ | 0000000000000000 | +------------------+ + | 0000000000000000 | + | 0000000000000000 | + | 0000000000000000 | + +------------------+ + +But, if we use LoRA Laning like before, we can waste less space. We would also +have to shuffle the input so it applies to the right adapter. + + D + +------------------+ + L | 1111111111111111 | D + +------------------+ | 1111111111111111 | +------------------+ + | 1111000000000000 | | 1111111111111111 | | 1111111111111111 | + T | 0000222200000000 | x L | 1111111111111111 | = T | 2222222222222222 | + +------------------+ | 2222222222222222 | +------------------+ + | 2222222222222222 | + | 2222222222222222 | + | 2222222222222222 | + +------------------+ + +Since this shuffling is the exact opposite of the operation we do at the end of +the Shrink operation, we can skip both shuffles. + +""" + +def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, + max_num_loras: int, idx_ref, inp_ref, lora_ref, + out_ref, acc_ref, mask_ref): @pl.when(pl.program_id(2) == 0) def _(): @@ -27,23 +153,33 @@ def _(): t = pl.program_id(0) - ones = jnp.ones((bL, ), dtype=jnp.float32) + ones = jnp.ones((lane_size, ), dtype=jnp.float32) - for i in range(max_num_loras): + base_lora_idx = 0 + for lane_idx in range(max_num_loras): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) valid = False for j in range(bT): - valid |= idx_ref[j + bT * t] == i + idx = idx_ref[j + bT * t] + for k in range(n_lora_lanes): + lora_idx = base_lora_idx + k + set_mask = idx == lora_idx + valid |= set_mask - @pl.when(idx_ref[j + bT * t] == i) - def _(): - mask_ref.at[j, :].set(ones) + @pl.when(set_mask) + def _(): + lane_start = k * lane_size + lane_end = lane_start + lane_size + + mask_ref.at[j, lane_start:lane_end].set(ones) + + base_lora_idx += n_lora_lanes @pl.when(valid) def _(): acc_ref[...] += jax.lax.dot_general( inp_ref[...], - lora_ref[i, ...], (((1, ), (1, )), ((), ())), + lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] 
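# ---------------------------------------------------------------------------
# Illustrative reference (an editorial sketch, not part of this patch): the
# pure-PyTorch function below spells out what the laned shrink computes,
# materialising the packed tile, the per-lane mask and the lane-offset shuffle
# discussed in the docstring above, all of which the Pallas kernel keeps
# on-chip. The names `laned_bgmv_shrink`, `lora_block` and `n_lanes` are
# assumptions made for the sketch, and it assumes L divides the tile size
# (the real wrapper pads the rank up to LORA_BLOCK when it does not).
# ---------------------------------------------------------------------------
import torch

def ref_bgmv_shrink(inputs, loras, idxs):
    # inputs: (T, D), loras: (N, L, D), idxs: (T,)  ->  (T, L)
    return torch.einsum("td,tld->tl", inputs, loras[idxs.long()])

def laned_bgmv_shrink(inputs, loras, idxs, lora_block=256):
    T, D = inputs.shape
    N, L, _ = loras.shape
    assert lora_block % L == 0
    idxs = idxs.long()
    n_lanes = lora_block // L                 # adapters packed into one tile
    pad_n = (-N) % n_lanes
    packed = torch.nn.functional.pad(loras, (0, 0, 0, 0, 0, pad_n))
    packed = packed.reshape((N + pad_n) // n_lanes, lora_block, D)
    out = torch.zeros(T, lora_block, dtype=inputs.dtype)
    for block in range(packed.shape[0]):
        prod = inputs @ packed[block].T       # one matmul covers n_lanes adapters
        mask = torch.zeros_like(prod)
        for lane in range(n_lanes):
            rows = idxs == block * n_lanes + lane
            mask[rows, lane * L:(lane + 1) * L] = 1.0
        out += prod * mask                    # keep each token's own lane only
    # Undo the per-lane column offset so every token's result lands in [0, L).
    # When a laned expand consumes `out` directly (enable_laning=True), this
    # shuffle and the matching one in expand cancel, which is why both can be
    # skipped.
    cols = (idxs % n_lanes)[:, None] * L + torch.arange(L)
    return out.gather(1, cols)

if __name__ == "__main__":
    T, D, L, N = 8, 64, 16, 5
    x, w = torch.randn(T, D), torch.randn(N, L, D)
    i = torch.randint(0, N, (T, ), dtype=torch.int32)
    assert torch.allclose(laned_bgmv_shrink(x, w, i), ref_bgmv_shrink(x, w, i),
                          atol=1e-5)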
@pl.when(pl.program_id(2) == pl.num_programs(2) - 1) @@ -52,7 +188,10 @@ def _(): @functools.partial(jax.jit, - static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) + static_argnames=[ + "TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK", + "N_LORA_LANES", "LANE_SIZE" + ]) def _bgmv_shrink( idxs: jax.Array, # (T, ) int32 inputs: jax.Array, # (T, D) model dtype @@ -60,13 +199,15 @@ def _bgmv_shrink( *, TOKEN_BLOCK: int, LORA_BLOCK: int, - DIM_BLOCK: int) -> jax.Array: # (T, L) model dtype + DIM_BLOCK: int, + N_LORA_LANES: int, + LANE_SIZE: int) -> jax.Array: # (T, L) model dtype T, D = inputs.shape N, L, _ = loras.shape return pl.pallas_call( kernel=functools.partial(_bgmv_shrink_kernel, TOKEN_BLOCK, LORA_BLOCK, - N), + N_LORA_LANES, LANE_SIZE, N), out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=1, @@ -104,27 +245,29 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, loras = loras.squeeze(axis=1) T, _ = inputs.shape - _, L, D = loras.shape - is_expand = L > D - - jax_import_guard() + N, L, D = loras.shape TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) LORA_BLOCK = 256 DIM_BLOCK = min(1024, next_multiple_of(D, 256)) - kernel = make_kernel_from_pallas( - functools.partial(_bgmv_shrink, - TOKEN_BLOCK=TOKEN_BLOCK, - LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK), bgmv_shrink_shape_function) + # See if we can fit multiple LoRAs in a register. This would activate LoRA + # laning + N_LORA_LANES = math.ceil(LORA_BLOCK / L) + LANE_SIZE = min(L, LORA_BLOCK) + if N_LORA_LANES > 1 and N > 1: + pad_N = next_multiple_of(N, N_LORA_LANES) - N + new_N = N + pad_N + + loras = torch.nn.functional.pad(loras, (0, 0, 0, 0, 0, pad_N)) + loras = loras.reshape((new_N // N_LORA_LANES, LORA_BLOCK, D)) + N, L, D = loras.shape # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. This has to happen in pytorch, doing it in Jax will lead to NaNs pad_L = 0 if LORA_BLOCK > L or L % LORA_BLOCK != 0: pad_L = next_multiple_of(L, LORA_BLOCK) - L - pad_D = 0 if DIM_BLOCK > D or D % DIM_BLOCK != 0: pad_D = next_multiple_of(D, DIM_BLOCK) - D @@ -140,6 +283,15 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, if pad_T != T: idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) + jax_import_guard() + kernel = make_kernel_from_pallas( + functools.partial(_bgmv_shrink, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK, + N_LORA_LANES=N_LORA_LANES, + LANE_SIZE=LANE_SIZE), bgmv_shrink_shape_function) + return kernel(idxs, inputs, loras)[:T, :L] @@ -240,26 +392,30 @@ def bgmv_expand_shape_function(idxs, inputs, loras): @impl(XLA_LIB, "bgmv_expand", "XLA") def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): + idxs: torch.IntTensor, enable_laning: bool): inputs = inputs.to(dtype=loras.dtype) if len(loras.shape) == 4: loras = loras.squeeze(axis=1) T, _ = inputs.shape - _, D, L = loras.shape - - jax_import_guard() + N, D, L = loras.shape TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) LORA_BLOCK = min(1024, next_multiple_of(L, 256)) DIM_BLOCK = 256 - kernel = make_kernel_from_pallas( - functools.partial(_bgmv_expand, - TOKEN_BLOCK=TOKEN_BLOCK, - LORA_BLOCK=LORA_BLOCK, - DIM_BLOCK=DIM_BLOCK), bgmv_expand_shape_function) + # See if we can fit multiple LoRAs in a register. 
This would activate LoRA + # laning + N_LORA_LANES = math.ceil(DIM_BLOCK / D) + if enable_laning and N_LORA_LANES > 1 and N > 1: + pad_N = next_multiple_of(N, N_LORA_LANES) - N + new_N = N + pad_N + + loras = torch.nn.functional.pad(loras, (0, 0, 0, 0, 0, pad_N)) + loras = loras.reshape((new_N // N_LORA_LANES, DIM_BLOCK, L)) + idxs = idxs // N_LORA_LANES + N, D, L = loras.shape # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU # register. This has to happen in pytorch, doing it in Jax will lead to NaNs @@ -282,12 +438,20 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, if pad_T != T: idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) + jax_import_guard() + + kernel = make_kernel_from_pallas( + functools.partial(_bgmv_expand, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK), bgmv_expand_shape_function) + return kernel(idxs, inputs, loras)[:T, :L] @impl(XLA_LIB, "bgmv_expand", "CompositeExplicitAutograd") def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): + idxs: torch.IntTensor, enable_laning: bool): T, _ = inputs.shape if len(loras.shape) == 4: diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 11296b68dc82..500dd53b713a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -73,8 +73,9 @@ def expand( x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool, + enable_laning: bool ): - return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs, enable_laning=enable_laning) def expand_slice( self, @@ -83,11 +84,11 @@ def expand_slice( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool, + enable_laning: bool ) -> torch.Tensor: return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, - y_offset, y_slice_size, add_inputs) + y_offset, y_slice_size, add_inputs, enable_laning=enable_laning) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -162,8 +163,8 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], - y_total_size=sum(output_slices), add_inputs=add_inputs, + enable_laning=kwargs["enable_laning"] ) offset_left += output_slices[slice_idx] return y.view_as(y_org) @@ -188,7 +189,7 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - return self.expand(y, x, lora_b_stacked, add_inputs) + return self.expand(y, x, lora_b_stacked, add_inputs, enable_laning=False) def add_lora_linear(self, y: torch.Tensor, @@ -247,6 +248,7 @@ def add_lora_linear(self, None, output_slices, add_inputs=True, + enable_laning=True, **kwargs) def add_lora_logits(self, @@ -291,7 +293,8 @@ def add_lora_logits(self, lora_b_stacked, y, self.sampler_indices, - add_inputs=True) + add_inputs=True, + enable_laning=True) return y.view_as(y_org) # This performs the same tensor ops as the base method, except it does them From fbb902ae343d601cecf4f0a5f992babf99e83a83 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 21:01:16 +0000 Subject: [PATCH 128/186] Updated kernel benchmarking script with lora laning Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 43 ++++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index c1bc08ced9af..cb714830bbfd 
100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -26,32 +26,30 @@ def create_tensors(T, D, L, N, dtype=torch.bfloat16, device='xla'): return inputs, loras, idxs -# SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 131072] -# HIDDEN_DIM = [256, 1024, 4096, 8192, 14336, 28672] -# LORA_RANKS = [8, 16, 32, 64, 128, 128] -# N_LORAS = [1, 2, 4, 8] - - -SEQ_LENS = [1024] -HIDDEN_DIM = [4096] -LORA_RANKS = [32] -N_LORAS = [1] +SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 131072] +HIDDEN_DIM = [256, 1024, 4096, 8192, 14336, 28672] +LORA_RANKS = [8, 16, 32, 64, 128, 128] +N_LORAS = [1, 2, 4, 8] @torch.compile(fullgraph=True, dynamic=False, backend="openxla") def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): return torch.einsum("td,tld->tl", inputs, loras[idxs]) @torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - return torch.ops.xla.bgmv(inputs, loras, idxs) +def bgmv_shrink(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + return torch.ops.xla.bgmv_shrink(inputs, loras, idxs) + +@torch.compile(fullgraph=True, dynamic=False, backend="openxla") +def bgmv_expand(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor, enable_laning: bool): + return torch.ops.xla.bgmv_expand(inputs, loras, idxs, enable_laning) @torch.compile(fullgraph=True, dynamic=False, backend="openxla") def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - # TODO: Fuse kernels - return bgmv( - bgmv(inputs, loras_a, idxs), + return bgmv_expand( + bgmv_shrink(inputs, loras_a, idxs), loras_b, - idxs + idxs, + enable_laning=True ) @torch.compile(fullgraph=True, dynamic=False, backend="openxla") @@ -69,24 +67,23 @@ def run_and_wait_torch(func, *args): return out @pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) -@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) +@pytest.mark.parametrize("func", [bgmv_shrink]) def test_bmark_shrink(benchmark, T, D, L, N, func): inputs, loras, idxs = create_tensors(T, D, L, N) - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=100) + benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=10) @pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, LORA_RANKS, HIDDEN_DIM, N_LORAS)) -@pytest.mark.parametrize("func", [bgmv, ref_bgmv]) +@pytest.mark.parametrize("func", [bgmv_expand]) def test_bmark_expand(benchmark, T, D, L, N, func): inputs, loras, idxs = create_tensors(T, D, L, N) - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=100) - + benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=10) @pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) -@pytest.mark.parametrize("func", [shrink_and_expand, ref_shrink_and_expand]) +@pytest.mark.parametrize("func", [shrink_and_expand]) def test_bmark_shrink_and_expand(benchmark, T, D, L, N, func): inputs, loras_a, idxs = create_tensors(T, D, L, N) _, loras_b, _ = create_tensors(T, L, D, N) - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras_a, loras_b, idxs), rounds=5, warmup_rounds=5, iterations=100) + 
benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras_a, loras_b, idxs), rounds=5, warmup_rounds=5, iterations=10) From 8ba274936b5103bb30e35e185240bc7eb3f8a78d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 21:38:18 +0000 Subject: [PATCH 129/186] Added error for when someone tries to use LoRA adapters on the V0 TPU backend Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_worker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 66911790662e..e298af5b197e 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -53,6 +53,12 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + + if vllm_config.lora_config is not None: + raise NotImplementedError( + """The V0 TPU backend doesn't support LoRA serving, please try \ + V1 by setting VLLM_USE_V1=1""" + ) def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" From 51d87a528f117491541de18eabda50f0524b81ea Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 21:40:10 +0000 Subject: [PATCH 130/186] Added test to buildkite Signed-off-by: Akshat Tripathi --- .buildkite/run-tpu-v1-test.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh index 6e1f79ae649e..8c39fbd714b2 100755 --- a/.buildkite/run-tpu-v1-test.sh +++ b/.buildkite/run-tpu-v1-test.sh @@ -33,6 +33,8 @@ docker run --privileged --net host --shm-size=16G -it \ && python3 /workspace/vllm/examples/offline_inference/tpu.py \ && echo TEST_6 \ && pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \ + && echo TEST_7 \ + && pytest -s -v /workspace/vllm/tests/tpu/test_lora.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling From 8b1dae8b22f52a97b357d0e3e0399f39743f6ab7 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 21:52:52 +0000 Subject: [PATCH 131/186] Lint Signed-off-by: Akshat Tripathi --- vllm/worker/tpu_worker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index f888c0228e77..2e9fe3d6a4d4 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -53,12 +53,11 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 - + if vllm_config.lora_config is not None: raise NotImplementedError( """The V0 TPU backend doesn't support LoRA serving, please try \ - V1 by setting VLLM_USE_V1=1""" - ) + V1 by setting VLLM_USE_V1=1""") def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" From aad109b4cfa8ada4c4271a5c1b308816dcccf1d9 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 22:07:09 +0000 Subject: [PATCH 132/186] Optimised single lora kernels Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 108 ++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 17bafef43ccb..21daedc96711 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -143,6 +143,7 @@ """ + def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref): @@ -151,36 +152,42 @@ def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, def _(): acc_ref[...] 
= jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - t = pl.program_id(0) - - ones = jnp.ones((lane_size, ), dtype=jnp.float32) - - base_lora_idx = 0 - for lane_idx in range(max_num_loras): - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - idx = idx_ref[j + bT * t] - for k in range(n_lora_lanes): - lora_idx = base_lora_idx + k - set_mask = idx == lora_idx - valid |= set_mask - - @pl.when(set_mask) - def _(): - lane_start = k * lane_size - lane_end = lane_start + lane_size - - mask_ref.at[j, lane_start:lane_end].set(ones) - - base_lora_idx += n_lora_lanes - - @pl.when(valid) - def _(): - acc_ref[...] += jax.lax.dot_general( - inp_ref[...], - lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] + if max_num_loras == 1 and n_lora_lanes == 1: + acc_ref[...] += jax.lax.dot_general(inp_ref[...], + lora_ref[0, ...], + (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) + else: + t = pl.program_id(0) + + ones = jnp.ones((lane_size, ), dtype=jnp.float32) + + base_lora_idx = 0 + for lane_idx in range(max_num_loras): + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + valid = False + for j in range(bT): + idx = idx_ref[j + bT * t] + for k in range(n_lora_lanes): + lora_idx = base_lora_idx + k + set_mask = idx == lora_idx + valid |= set_mask + + @pl.when(set_mask) + def _(): + lane_start = k * lane_size + lane_end = lane_start + lane_size + + mask_ref.at[j, lane_start:lane_end].set(ones) + + base_lora_idx += n_lora_lanes + + @pl.when(valid) + def _(): + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): @@ -320,30 +327,35 @@ def _bgmv_expand_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - t = pl.program_id(0) + if max_num_loras == 1: + acc_ref[...] += jax.lax.dot(inp_ref[...], + lora_ref[0, ...], + preferred_element_type=jnp.float32) + else: + t = pl.program_id(0) - ones = jnp.ones((bL, ), dtype=jnp.float32) + ones = jnp.ones((bL, ), dtype=jnp.float32) - for i in range(max_num_loras): - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - valid |= idx_ref[j + bT * t] == i + for i in range(max_num_loras): + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + valid = False + for j in range(bT): + valid |= idx_ref[j + bT * t] == i + + @pl.when(idx_ref[j + bT * t] == i) + def _(): + mask_ref.at[j, :].set(ones) - @pl.when(idx_ref[j + bT * t] == i) + @pl.when(valid) def _(): - mask_ref.at[j, :].set(ones) + acc_ref[...] += jax.lax.dot( + inp_ref[...], + lora_ref[i, ...], + preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when(valid) + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): - acc_ref[...] += jax.lax.dot( - inp_ref[...], - lora_ref[i, ...], - preferred_element_type=jnp.float32) * mask_ref[...] - - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) - def _(): - out_ref[...] = acc_ref[...].astype(out_ref.dtype) + out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @functools.partial(jax.jit, From b09d595a1a36de1d9c2253e0c06a0459edf4285d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 27 Mar 2025 23:03:02 +0000 Subject: [PATCH 133/186] Fixed compilation bug Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index df512b58ff09..cc50b0f53a0b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -815,8 +815,7 @@ def capture_model(self) -> None: device=device, dtype=torch.bfloat16) # Compile for [8, 16, .., 128,.., `self.max_num_reqs`] - while num_reqs_to_sample <= self.max_num_reqs and \ - num_reqs_to_sample <= num_tokens: + while num_reqs_to_sample <= num_tokens: indices = torch.zeros( num_reqs_to_sample, dtype=torch.int32, @@ -836,6 +835,8 @@ def capture_model(self) -> None: dummy_hidden, sampling_meta) out = out.cpu() + if num_reqs_to_sample >= self.max_num_reqs: + break # Make sure to compile the `max_num_reqs` upper-limit case num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit( num_reqs_to_sample + 1, self.max_num_reqs) From 72d95c64e8cc0ea3a4aa96ab3f967ac24f5d1e9d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 04:20:35 +0000 Subject: [PATCH 134/186] Fixed LoRA Laning bug Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 21daedc96711..b50ba7f382ac 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -310,7 +310,12 @@ def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, if len(loras.shape) == 4: loras = loras.squeeze(axis=1) - _, L, _ = loras.shape + N, L, _ = loras.shape + + LORA_BLOCK = 256 + N_LORA_LANES = math.ceil(LORA_BLOCK / L) + if N_LORA_LANES > 1 and N > 1: + L = LORA_BLOCK return torch.empty((T, L), device=inputs.device) From be0915c4ffe2a28b3d735357780964dbac602447 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 19:55:55 +0000 Subject: [PATCH 135/186] Fixed extra recompilations Signed-off-by: Akshat Tripathi --- vllm/adapter_commons/utils.py | 10 ++++++++++ vllm/v1/worker/tpu_model_runner.py | 4 ++++ 2 files changed, 14 insertions(+) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index c2dc5433cc65..387c8890ef67 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -2,6 +2,8 @@ from typing import Any, Callable, Dict, Optional, Set +from vllm.platforms import current_platform + ## model functions def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], @@ -52,6 +54,14 @@ def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: apply_adapters_func(requests) + + # We need this to ensure that adapter loading/unloading and updating + # metadata are compiled as 2 separate graphs on TPU, otherwise we get + # runtime recompilations. 
+ if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + xm.mark_step() + set_adapter_mapping_func(mapping) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 41480f7b7ca3..2df9afbfc752 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -801,6 +801,10 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: + super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) + xm.mark_step() + def capture_model(self) -> None: """Compile the model.""" From 478a8bbf162aaaca60b09304afd20c280c93bd0a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 19:57:29 +0000 Subject: [PATCH 136/186] Lint Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 84 +++++++++++++++++--------- tests/lora/tpu/test_pallas_kernels.py | 8 +-- vllm/adapter_commons/utils.py | 4 +- vllm/lora/ops/xla_ops/pallas.py | 2 +- vllm/lora/punica_wrapper/punica_tpu.py | 66 ++++++++++---------- 5 files changed, 92 insertions(+), 72 deletions(-) diff --git a/bmark_kernels.py b/bmark_kernels.py index cb714830bbfd..80a022da8f40 100644 --- a/bmark_kernels.py +++ b/bmark_kernels.py @@ -1,10 +1,11 @@ -from functools import partial +# SPDX-License-Identifier: Apache-2.0 import itertools -import pytest +from functools import partial +import pytest import torch import torch_xla.core.xla_model as xm -import vllm.lora.ops.xla_ops.pallas as pl + def create_tensors(T, D, L, N, dtype=torch.bfloat16, device='xla'): """ @@ -19,46 +20,51 @@ def create_tensors(T, D, L, N, dtype=torch.bfloat16, device='xla'): loras: torch.Tensor - shape (N, 1, L, D) idxs: torch.IntTensor - shape (T, ) - all values must be in [0, N) """ - + inputs = torch.randn((T, D), dtype=dtype, device=device) loras = torch.randn((N, L, D), dtype=dtype, device=device) - idxs = torch.randint(0, N, (T,), dtype=torch.int32, device=device) - + idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device=device) + return inputs, loras, idxs + SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 131072] HIDDEN_DIM = [256, 1024, 4096, 8192, 14336, 28672] LORA_RANKS = [8, 16, 32, 64, 128, 128] N_LORAS = [1, 2, 4, 8] + @torch.compile(fullgraph=True, dynamic=False, backend="openxla") def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): return torch.einsum("td,tld->tl", inputs, loras[idxs]) + @torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def bgmv_shrink(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): +def bgmv_shrink(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): return torch.ops.xla.bgmv_shrink(inputs, loras, idxs) + @torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def bgmv_expand(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor, enable_laning: bool): +def bgmv_expand(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor, enable_laning: bool): return torch.ops.xla.bgmv_expand(inputs, loras, idxs, enable_laning) + @torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - return bgmv_expand( - bgmv_shrink(inputs, loras_a, idxs), - loras_b, - idxs, - enable_laning=True - ) +def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, 
+ loras_b: torch.Tensor, idxs: torch.IntTensor): + return bgmv_expand(bgmv_shrink(inputs, loras_a, idxs), + loras_b, + idxs, + enable_laning=True) + @torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def ref_shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, loras_b: torch.Tensor, idxs: torch.IntTensor): - return ref_bgmv( - ref_bgmv(inputs, loras_a, idxs), - loras_b, - idxs - ) +def ref_shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, + loras_b: torch.Tensor, idxs: torch.IntTensor): + return ref_bgmv(ref_bgmv(inputs, loras_a, idxs), loras_b, idxs) + def run_and_wait_torch(func, *args): out = func(*args) @@ -66,24 +72,42 @@ def run_and_wait_torch(func, *args): xm.wait_device_ops() return out -@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) + +@pytest.mark.parametrize( + "T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) @pytest.mark.parametrize("func", [bgmv_shrink]) def test_bmark_shrink(benchmark, T, D, L, N, func): inputs, loras, idxs = create_tensors(T, D, L, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=10) -@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, LORA_RANKS, HIDDEN_DIM, N_LORAS)) + benchmark.pedantic(partial(run_and_wait_torch, func), + args=(inputs, loras, idxs), + rounds=5, + warmup_rounds=5, + iterations=10) + + +@pytest.mark.parametrize( + "T,D,L,N", itertools.product(SEQ_LENS, LORA_RANKS, HIDDEN_DIM, N_LORAS)) @pytest.mark.parametrize("func", [bgmv_expand]) def test_bmark_expand(benchmark, T, D, L, N, func): inputs, loras, idxs = create_tensors(T, D, L, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras, idxs), rounds=5, warmup_rounds=5, iterations=10) -@pytest.mark.parametrize("T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) + benchmark.pedantic(partial(run_and_wait_torch, func), + args=(inputs, loras, idxs), + rounds=5, + warmup_rounds=5, + iterations=10) + + +@pytest.mark.parametrize( + "T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) @pytest.mark.parametrize("func", [shrink_and_expand]) def test_bmark_shrink_and_expand(benchmark, T, D, L, N, func): inputs, loras_a, idxs = create_tensors(T, D, L, N) _, loras_b, _ = create_tensors(T, L, D, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), args=(inputs, loras_a, loras_b, idxs), rounds=5, warmup_rounds=5, iterations=10) + + benchmark.pedantic(partial(run_and_wait_torch, func), + args=(inputs, loras_a, loras_b, idxs), + rounds=5, + warmup_rounds=5, + iterations=10) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index d253351158a8..5d9b64386db2 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -85,6 +85,7 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): # Compare with reference output assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) + # Parameterize tests with various shapes and dtypes @pytest.mark.parametrize("T", N_TOKENS) @pytest.mark.parametrize("D", HIDDEN_SIZES) @@ -100,12 +101,7 @@ def test_lora_laning_correctness(T, D, L, N, dtype, seed): r2 = ref_bgmv(r1, loras_b, idxs) o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs) - o2 = torch.ops.xla.bgmv_expand( - o1, - loras_b.transpose(2, 3), - idxs, - True - ) + o2 = torch.ops.xla.bgmv_expand(o1, loras_b.transpose(2, 3), idxs, True) # Compare with 
reference output assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index 387c8890ef67..d8b7a4232b24 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -55,8 +55,8 @@ def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], set_adapter_mapping_func) -> None: apply_adapters_func(requests) - # We need this to ensure that adapter loading/unloading and updating - # metadata are compiled as 2 separate graphs on TPU, otherwise we get + # We need this to ensure that adapter loading/unloading and updating + # metadata are compiled as 2 separate graphs on TPU, otherwise we get # runtime recompilations. if current_platform.is_tpu(): import torch_xla.core.xla_model as xm diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index b50ba7f382ac..ef9987d34c7c 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -311,7 +311,7 @@ def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, loras = loras.squeeze(axis=1) N, L, _ = loras.shape - + LORA_BLOCK = 256 N_LORA_LANES = math.ceil(LORA_BLOCK / L) if N_LORA_LANES > 1 and N > 1: diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 500dd53b713a..789cdaa9c3f0 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -67,28 +67,26 @@ def shrink( ): return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_inputs: bool, - enable_laning: bool - ): - return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs, enable_laning=enable_laning) - - def expand_slice( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool, - enable_laning: bool - ) -> torch.Tensor: - return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, - y_offset, y_slice_size, add_inputs, enable_laning=enable_laning) + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool, enable_laning: bool): + return bgmv_expand(x, + w_t_all, + y, + self.token_lora_indices, + add_inputs, + enable_laning=enable_laning) + + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + add_inputs: bool, enable_laning: bool) -> torch.Tensor: + return bgmv_expand_slice(x, + w_t_all, + y, + self.token_lora_indices, + y_offset, + y_slice_size, + add_inputs, + enable_laning=enable_laning) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -157,15 +155,13 @@ def add_expand(self, y = self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - enable_laning=kwargs["enable_laning"] - ) + y = self.expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + enable_laning=kwargs["enable_laning"]) offset_left += output_slices[slice_idx] return y.view_as(y_org) @@ -189,7 +185,11 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - return self.expand(y, x, lora_b_stacked, add_inputs, 
enable_laning=False) + return self.expand(y, + x, + lora_b_stacked, + add_inputs, + enable_laning=False) def add_lora_linear(self, y: torch.Tensor, @@ -276,7 +276,7 @@ def add_lora_logits(self, buffer (Optional[torch.Tensor]):Default to None. """ return y - + y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) From 4178e58271c7201df7d9bae4afb6e66509aa3a2f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 20:02:57 +0000 Subject: [PATCH 137/186] Lint Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2df9afbfc752..84780a2d73d1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -801,8 +801,10 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype - def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: - super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) + def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, + lora_requests) -> None: + super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) xm.mark_step() def capture_model(self) -> None: @@ -954,7 +956,7 @@ def sample_from_hidden( sampling_metadata: TPUSupportedSamplingMetadata, ) -> torch.Tensor: """ - Sample with xla-friendly function. This function is to be traced + Sample with xla-friendly function. This function is to be traced separately from `forward` for lighter compilation overhead. """ # Tensor `sample_hidden_states` is of fixed pre-compiled size. @@ -993,13 +995,13 @@ def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int: def _get_paddings(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]: - """Generate a list of padding size, starting from min_token_size, + """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size - + If padding_gap == 0 then: increase 2X each time (exponential) else: - first increase the size to twice, + first increase the size to twice, then increase the padding size by padding_gap. """ paddings = [] From 8a3009de0e3c7ff7bf55ca2382245bdeaf867d83 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 20:05:10 +0000 Subject: [PATCH 138/186] Added type annotation to lora_output Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 32eac38ecade..8b45d8b7b647 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -262,14 +262,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: -1, ) - lora_output = self.punica_wrapper.add_lora_embedding( - full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) - - # lora_output is None if the platform supports inplace updates. 
- # Otherwise it's a tensor, so we update the output manually + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + if not current_platform.can_update_inplace(): full_output = lora_output @@ -418,9 +417,10 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - lora_output = self.punica_wrapper.add_lora_linear( - output, x, self.lora_a_stacked, self.lora_b_stacked, - self.lora_bias_stacked, 1.0, self.output_slices) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) if not current_platform.can_update_inplace(): output = lora_output @@ -1164,10 +1164,10 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - # LogitsProcessorWithLoRA always uses bgmv - lora_output = self.punica_wrapper.add_lora_logits( - logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, - 1.0) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) if not current_platform.can_update_inplace(): logits = lora_output From 1d6085a9d0994f7f65f613dd5d16eab7eb1a0b5c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 31 Mar 2025 22:20:59 +0000 Subject: [PATCH 139/186] Removed unused function/parameter Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 1 - vllm/lora/ops/xla_ops/pallas.py | 8 -------- vllm/lora/punica_wrapper/punica_tpu.py | 19 +++++-------------- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index f5088c37745f..c89f4b137a12 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -34,7 +34,6 @@ def bgmv_expand(inputs: torch.Tensor, def bgmv_shrink(inputs: torch.Tensor, lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index ef9987d34c7c..669a35e45098 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -478,14 +478,6 @@ def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, return torch.empty((T, L), device=inputs.device) - -def largest_divisor(n: int, divs: List[int]) -> int: - for div in sorted(divs, reverse=True): - if n % div == 0: - return div - return max(divs) - - def next_multiple_of(n: int, mult: int) -> int: return math.ceil(n / mult) * mult diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 789cdaa9c3f0..9c7b59bace25 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -60,12 +60,11 @@ def sampler_indices_padded(self) -> torch.Tensor: def shrink( self, - y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float, ): - return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, self.token_lora_indices, scale) def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool, enable_laning: bool): @@ -115,7 +114,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], y_org = y_s y_s = y_s.view(-1, y_s.shape[-1]) - y_s = self.shrink(y_s, x, lora_s, scale) 
+ y_s = self.shrink(x, lora_s, scale) y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) @@ -275,20 +274,11 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ - return y - y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - rank = lora_b_stacked.size(-1) - if buffer is None: - buffer = torch.zeros((x.size(0), rank), - dtype=y.dtype, - device=x.device) - - buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, - scale) + buffer = bgmv_shrink(x, lora_a_stacked, self.sampler_indices, scale) y = bgmv_expand(buffer, lora_b_stacked, y, @@ -355,7 +345,8 @@ def _pad_prompt_mapping( self, prompt_mapping: Tuple[int, ...]) -> Tuple[int, ...]: num_reqs = len(prompt_mapping) - # From vllm/v1/worker/tppu_model_runner:52, but need to avoid a circular import + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # import MIN_NUM_SEQS = 8 padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) From f20823445b3d2baf0137a39bebe2fc8ad643f8ae Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 00:01:40 +0000 Subject: [PATCH 140/186] Removed redundant padding in kernel for larger lora/dim sizes Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 669a35e45098..d82a8c878dfd 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -256,7 +256,7 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) LORA_BLOCK = 256 - DIM_BLOCK = min(1024, next_multiple_of(D, 256)) + DIM_BLOCK = largest_divisor(D, [256, 512, 1024]) # See if we can fit multiple LoRAs in a register. This would activate LoRA # laning @@ -419,7 +419,7 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, N, D, L = loras.shape TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) - LORA_BLOCK = min(1024, next_multiple_of(L, 256)) + LORA_BLOCK = largest_divisor(L, [256, 512, 1024]) DIM_BLOCK = 256 # See if we can fit multiple LoRAs in a register. 
This would activate LoRA @@ -478,6 +478,14 @@ def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, return torch.empty((T, L), device=inputs.device) + +def largest_divisor(n: int, divs: List[int]) -> int: + for div in sorted(divs, reverse=True): + if n % div == 0: + return div + return max(divs) + + def next_multiple_of(n: int, mult: int) -> int: return math.ceil(n / mult) * mult From ec0e1813479665e0653b5d9cac711a94d43937af Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 18:09:26 +0000 Subject: [PATCH 141/186] Moved xm.mark_step() calls to move appropriate places Signed-off-by: Akshat Tripathi --- vllm/adapter_commons/utils.py | 10 -------- vllm/lora/models.py | 17 +++++++++++++ vllm/lora/punica_wrapper/punica_tpu.py | 35 ++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index d8b7a4232b24..c2dc5433cc65 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -2,8 +2,6 @@ from typing import Any, Callable, Dict, Optional, Set -from vllm.platforms import current_platform - ## model functions def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], @@ -54,14 +52,6 @@ def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: apply_adapters_func(requests) - - # We need this to ensure that adapter loading/unloading and updating - # metadata are compiled as 2 separate graphs on TPU, otherwise we get - # runtime recompilations. - if current_platform.is_tpu(): - import torch_xla.core.xla_model as xm - xm.mark_step() - set_adapter_mapping_func(mapping) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 8164d919ca8b..40d213f101d5 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -33,12 +33,27 @@ from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available logger = init_logger(__name__) _GLOBAL_LORA_ID = 0 +def maybe_compile_xla_graph(func): + if not current_platform.is_tpu(): + return func + + import torch_xla.core.xla_model as xm + + def wrapper(*args, **kwargs): + logger.warning("compiling with xla") + result = func(*args, **kwargs) + + xm.mark_step() + xm.wait_device_ops() + return result + return wrapper @dataclass class LongContextLoRAContext: @@ -378,6 +393,7 @@ def lora_slots(self) -> int: def adapter_slots(self) -> int: return self.lora_slots + @maybe_compile_xla_graph def activate_adapter( self, lora_id: int, @@ -628,6 +644,7 @@ def _register_packed_modules(self, module_full_name: str) -> None: prefix + "." 
+ r if prefix else r for r in replacements ] + @maybe_compile_xla_graph def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): replacement_loras: List[Optional[LoRALayerWeights]] = [] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 9c7b59bace25..65c42f472df2 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch +import torch.nn.functional as F +import torch_xla.core.xla_model as xm from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink from vllm.lora.punica_wrapper.utils import convert_mapping @@ -335,6 +337,7 @@ def _update_base_metadata( else: self._long_lora_indices.zero_() self.indices_len[:] = indices_len + xm.mark_step() def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 @@ -354,3 +357,35 @@ def _pad_prompt_mapping( padding = [-1] * pad_len return tuple(list(prompt_mapping) + padding) + +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): + selected_loras = loras[idxs] + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(axis=1) + + batch_size, output_size, input_size = selected_loras.shape + return (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + +def bgmv_expand1(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, + *, + enable_laning: bool = False): + + outputs = ref_bgmv(inputs, lora_b_weights, lora_indices_tensor) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if output_tensor.shape[1] > outputs.shape[1]: + outputs = F.pad(outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) + + if add_inputs: + return output_tensor + outputs[:limit, :output_tensor.shape[1]] + else: + return outputs[:limit, :output_tensor.shape[1]] \ No newline at end of file From 38de473390261fbc04aa7f1b1a991a16adc4dcfe Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 18:10:41 +0000 Subject: [PATCH 142/186] Reduced number of graphs compiled Signed-off-by: Akshat Tripathi --- vllm/v1/sample/tpu/metadata.py | 10 ++++++---- vllm/v1/worker/tpu_model_runner.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 89d3ddf51d74..985f49f0381e 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -102,11 +102,8 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p, DEFAULT_SAMPLING_PARAMS["min_p"]) - xm.mark_step() - xm.wait_device_ops() - # Slice persistent device tensors to a fixed pre-compiled padded shape. - return cls( + input_batch = cls( temperature=input_batch.temperature[:padded_num_reqs], # Scalar tensor for xla-friendly tracing. 
all_greedy=torch.tensor(input_batch.all_greedy, @@ -118,3 +115,8 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, min_p=input_batch.min_p[:padded_num_reqs], generators=input_batch.generators, indices_do_sample=indices_do_sample) + + xm.mark_step() + xm.wait_device_ops() + + return input_batch \ No newline at end of file diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 84780a2d73d1..d0bc7ddb3bdb 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -784,6 +784,8 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: num_seqs=num_seqs, ) + xm.mark_step() # Capture tensors created when setting up + if self.is_multimodal_model: torch._dynamo.mark_dynamic(inputs_embeds, 0) else: @@ -805,7 +807,7 @@ def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) - xm.mark_step() + xm.mark_step() # Captures metadata updates def capture_model(self) -> None: """Compile the model.""" From 8dabfab9bad9a0f9e04995cbf2612284b8a15ba0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 22:15:45 +0000 Subject: [PATCH 143/186] Fixed memory usage problem Signed-off-by: Akshat Tripathi --- vllm/lora/models.py | 2 +- vllm/v1/worker/lora_model_runner_mixin.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 40d213f101d5..1c39d5b41074 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -51,7 +51,6 @@ def wrapper(*args, **kwargs): result = func(*args, **kwargs) xm.mark_step() - xm.wait_device_ops() return result return wrapper @@ -412,6 +411,7 @@ def activate_adapter( logger.debug("Activating LoRA. int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id + # return True for module_name, module in self.modules.items(): module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a8a19e0e6206..5fcc70695492 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.platforms import current_platform from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -126,7 +127,10 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, yield # __exit__ code - self.lora_manager.remove_all_adapters() + # Disabling remove_all_adapters on the TPU backend allows us to save + # quite a bit of RAM. E.g. 
we save 2.22 GB with Llama3.1 8B + if not current_platform.is_tpu(): + self.lora_manager.remove_all_adapters() def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From f5949a7e42a72c63ff035256196cc6621aa885e0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 22:19:17 +0000 Subject: [PATCH 144/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/models.py | 12 +++++++----- vllm/v1/sample/tpu/metadata.py | 4 ++-- vllm/v1/worker/tpu_model_runner.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 1c39d5b41074..04fdd26abc6c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -40,20 +40,22 @@ _GLOBAL_LORA_ID = 0 + def maybe_compile_xla_graph(func): if not current_platform.is_tpu(): return func - + import torch_xla.core.xla_model as xm def wrapper(*args, **kwargs): - logger.warning("compiling with xla") result = func(*args, **kwargs) - + xm.mark_step() return result + return wrapper + @dataclass class LongContextLoRAContext: """Context for lora adapters that support long context.""" @@ -213,7 +215,7 @@ def from_local_checkpoint( weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. - + Args: lora_dir: The local path that has lora data. expected_lora_modules: Name of modules that are expected to be @@ -621,7 +623,7 @@ def _match_target_modules(self, module_name: str): def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ Regarding multimodal models, vLLM currently only supports adding LoRA to - language model. LoRA for other modules, such as the vision tower, will + language model. LoRA for other modules, such as the vision tower, will be filtered out. 
""" if self.supports_mm: diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 985f49f0381e..4defed6e5c92 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -118,5 +118,5 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, xm.mark_step() xm.wait_device_ops() - - return input_batch \ No newline at end of file + + return input_batch diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d0bc7ddb3bdb..6926032369ee 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -784,7 +784,7 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None: num_seqs=num_seqs, ) - xm.mark_step() # Capture tensors created when setting up + xm.mark_step() # Capture tensors created when setting up if self.is_multimodal_model: torch._dynamo.mark_dynamic(inputs_embeds, 0) @@ -807,7 +807,7 @@ def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) - xm.mark_step() # Captures metadata updates + xm.mark_step() # Captures metadata updates def capture_model(self) -> None: """Compile the model.""" From a7ae2881ff8cf24d1e2a23912999eec4160e519d Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 22:25:32 +0000 Subject: [PATCH 145/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 7 +++++-- vllm/v1/sample/tpu/metadata.py | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index d82a8c878dfd..846393741fc2 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -12,6 +12,9 @@ from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, make_kernel_from_pallas) +# Ignore "Function definition does not bind loop variable" errors in Pallas +#ruff: noqa: B023 + XLA_LIB.define( "bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") @@ -19,8 +22,8 @@ # be the outputs of a LoRA laned bgmv_shrink. This is not always the case when # we use bgmv_expand XLA_LIB.define( - "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs, bool enable_laning) -> Tensor" -) + "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs, bool enable_laning) \ + -> Tensor") """ LoRA Laning Optimization for TPU Matrix Multiplication diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 4defed6e5c92..c1b49c56d0ac 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -103,7 +103,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) # Slice persistent device tensors to a fixed pre-compiled padded shape. - input_batch = cls( + tpu_sampling_metadata = cls( temperature=input_batch.temperature[:padded_num_reqs], # Scalar tensor for xla-friendly tracing. 
all_greedy=torch.tensor(input_batch.all_greedy, @@ -119,4 +119,4 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, xm.mark_step() xm.wait_device_ops() - return input_batch + return tpu_sampling_metadata From 2e67aa8be8cd2758c7bdf74542409f383cdcf483 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 22:42:58 +0000 Subject: [PATCH 146/186] Removed first inference recompilation Signed-off-by: Akshat Tripathi --- vllm/lora/models.py | 18 ------------ vllm/lora/punica_wrapper/punica_tpu.py | 38 +++----------------------- 2 files changed, 4 insertions(+), 52 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 04fdd26abc6c..ca1cb2d3fa01 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -33,7 +33,6 @@ from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper -from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -41,21 +40,6 @@ _GLOBAL_LORA_ID = 0 -def maybe_compile_xla_graph(func): - if not current_platform.is_tpu(): - return func - - import torch_xla.core.xla_model as xm - - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - - xm.mark_step() - return result - - return wrapper - - @dataclass class LongContextLoRAContext: """Context for lora adapters that support long context.""" @@ -394,7 +378,6 @@ def lora_slots(self) -> int: def adapter_slots(self) -> int: return self.lora_slots - @maybe_compile_xla_graph def activate_adapter( self, lora_id: int, @@ -646,7 +629,6 @@ def _register_packed_modules(self, module_full_name: str) -> None: prefix + "." + r if prefix else r for r in replacements ] - @maybe_compile_xla_graph def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): replacement_loras: List[Optional[LoRALayerWeights]] = [] diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 65c42f472df2..08b517a87f5c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch -import torch.nn.functional as F import torch_xla.core.xla_model as xm from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink @@ -300,6 +299,9 @@ def _update_base_metadata( extra_vocab_size: int, long_lora_context: Optional["LongContextLoRAContext"] = None, ): + # Make sure we don't accidentally collect outside operations + xm.mark_step() + # Pad the prompt mapping to avoid running into recompiles on the TPU # TODO: Should this happen inside mapping internally? If so how can we # avoid having backend specific LoRAMapping classes? 
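The xm.mark_step() fence added above leans on how torch_xla's lazy tensors work:
every tensor op is recorded into the current graph until mark_step() flushes it,
and each distinct traced graph is compiled once and then served from a cache.
Fencing the metadata copies into their own graph keeps whatever ran before them
(for example adapter weight loading, which varies from request to request) from
changing the traced graph and triggering a recompile. A minimal sketch of that
pattern follows; it assumes an XLA/TPU device is available, and the names
(`indices`, `update_metadata`) are illustrative rather than vLLM APIs:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
indices = torch.zeros(128, dtype=torch.int32, device=device)  # fixed padded size

def update_metadata(new_indices: torch.Tensor) -> None:
    # `new_indices` is assumed to already be padded to the same fixed length,
    # so the traced graph always has the same shape.
    xm.mark_step()              # cut off whatever the caller traced before this
    indices.copy_(new_indices)  # the copy traces into a small graph of its own
    xm.mark_step()              # ...which XLA compiles once and reuses afterwards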
@@ -348,7 +350,7 @@ def _pad_prompt_mapping( self, prompt_mapping: Tuple[int, ...]) -> Tuple[int, ...]: num_reqs = len(prompt_mapping) - # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular # import MIN_NUM_SEQS = 8 @@ -357,35 +359,3 @@ def _pad_prompt_mapping( padding = [-1] * pad_len return tuple(list(prompt_mapping) + padding) - -def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): - selected_loras = loras[idxs] - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(axis=1) - - batch_size, output_size, input_size = selected_loras.shape - return (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) - -def bgmv_expand1(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, - *, - enable_laning: bool = False): - - outputs = ref_bgmv(inputs, lora_b_weights, lora_indices_tensor) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad(outputs, - (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) - - if add_inputs: - return output_tensor + outputs[:limit, :output_tensor.shape[1]] - else: - return outputs[:limit, :output_tensor.shape[1]] \ No newline at end of file From 27b3c52987aaba59f9bf5075f260b55ae07a03c5 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 23:15:01 +0000 Subject: [PATCH 147/186] Fixed more recompilations Signed-off-by: Akshat Tripathi --- vllm/lora/models.py | 1 - vllm/v1/sample/tpu/metadata.py | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index ca1cb2d3fa01..8b777cdc086b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -396,7 +396,6 @@ def activate_adapter( logger.debug("Activating LoRA. int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id - # return True for module_name, module in self.modules.items(): module_lora = self._get_lora_layer_weights(lora_model, module_name) if module_lora: diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index c1b49c56d0ac..9b7a54d7476e 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -101,9 +101,12 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k) copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p, DEFAULT_SAMPLING_PARAMS["min_p"]) + + xm.mark_step() + xm.wait_device_ops() # Slice persistent device tensors to a fixed pre-compiled padded shape. - tpu_sampling_metadata = cls( + return cls( temperature=input_batch.temperature[:padded_num_reqs], # Scalar tensor for xla-friendly tracing. 
all_greedy=torch.tensor(input_batch.all_greedy, @@ -115,8 +118,3 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, xm.mark_step() xm.wait_device_ops() - - return tpu_sampling_metadata From d1452af3cd785a8051cb8bbeab50bdc4ea55c905 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 23:27:59 +0000 Subject: [PATCH 148/186] Added flag to disable add_lora_logits() Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 08b517a87f5c..8dab65c803cf 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -17,8 +17,6 @@ from .punica_base import PunicaWrapperBase -# The platforms that are compatible with the PyTorch-native implementation can -# inherit this class class PunicaWrapperTPU(PunicaWrapperBase): """ PunicaWrapperTPU is designed to manage and provide metadata for the punica @@ -39,6 +37,9 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) + + # Debug only + self.disable_add_lora_logits = True def mark_compiled(self): torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) @@ -275,6 +276,8 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ + if self.disable_add_lora_logits: + return y y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) From 1cc89a57363cf458ae073c00dac7eb3e4e1cdc16 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 1 Apr 2025 23:38:26 +0000 Subject: [PATCH 149/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 2 +- vllm/v1/sample/tpu/metadata.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 8dab65c803cf..320829323222 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -37,7 +37,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) - + # Debug only self.disable_add_lora_logits = True diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py index 9b7a54d7476e..89d3ddf51d74 100644 --- a/vllm/v1/sample/tpu/metadata.py +++ b/vllm/v1/sample/tpu/metadata.py @@ -101,7 +101,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k) copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p, DEFAULT_SAMPLING_PARAMS["min_p"]) - + xm.mark_step() xm.wait_device_ops() From 93d3e8f925ec04b590c306c4d70632864694c672 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 2 Apr 2025 04:30:12 +0000 Subject: [PATCH 150/186] Fixed performance issue where the sampler would face long stalls Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 5 ----- vllm/v1/worker/tpu_model_runner.py | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py
b/vllm/lora/punica_wrapper/punica_tpu.py index 320829323222..c604cbbbf2cb 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -38,9 +38,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) - # Debug only - self.disable_add_lora_logits = True - def mark_compiled(self): torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) @@ -276,8 +273,6 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ - if self.disable_add_lora_logits: - return y y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 6926032369ee..7c283600b0e1 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -179,6 +179,12 @@ def __init__( max_token_size=self.max_num_tokens, padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP) + if self.lora_config is not None: + # This makes us pad at initialisation time so we can avoid padding + # at runtime, which introduces long stalls + self.lora_config.max_lora_rank = _get_padded_lora_rank( + self.lora_config.max_lora_rank, self.lora_config.max_loras) + def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: """Update the cached states and the persistent batch with the scheduler output. @@ -1039,6 +1045,18 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] +def _get_padded_lora_rank(max_lora_rank: int, max_num_loras: int) -> int: + LORA_BLOCK_SIZE = 256 # Same as in the pallas kernel + + max_num_loras += 1 + + # If we have enough LoRAs to use laning without padding + if max_lora_rank * max_num_loras >= LORA_BLOCK_SIZE: + return max_lora_rank + + return 1 << (LORA_BLOCK_SIZE // max_num_loras).bit_length() + + def _create_dummy_scheduled_tokens(total_tokens: int, num_prompts: int) -> np.ndarray: assert num_prompts <= total_tokens, "Expected num_prompts < total_tokens" From e1aaed6344df79d8e8c01a2408de69d49aafbf0b Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 2 Apr 2025 17:04:14 +0000 Subject: [PATCH 151/186] Fixed laning integration bug Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index c604cbbbf2cb..7ab8bf7c3318 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -16,7 +16,6 @@ from .punica_base import PunicaWrapperBase - class PunicaWrapperTPU(PunicaWrapperBase): """ PunicaWrapperTPU is designed to manage and provide metadata for the punica @@ -102,19 +101,12 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ - x = x.view(-1, x.shape[-1]) new_y = [] for slice_idx in range(len(lora_a_stacked)): - y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - - y_org = y_s - y_s = y_s.view(-1, y_s.shape[-1]) - y_s = self.shrink(x, lora_s, scale) - y_s = y_s.view_as(y_org) new_y.append(y_s) return tuple(new_y) From 62500e19724f881bb6b62ab3081d7e4127e586d7 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 2 Apr 2025 17:11:50 +0000 Subject: [PATCH 152/186] Lint Signed-off-by: Akshat Tripathi 
--- vllm/lora/punica_wrapper/punica_tpu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 7ab8bf7c3318..1893fcf32f47 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -16,6 +16,7 @@ from .punica_base import PunicaWrapperBase + class PunicaWrapperTPU(PunicaWrapperBase): """ PunicaWrapperTPU is designed to manage and provide metadata for the punica From eb72ab659670326e80a8b6d73c19c7e2a3bf2f44 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 13:01:38 +0000 Subject: [PATCH 153/186] Removed LoRA vocab padding for TPU Signed-off-by: Akshat Tripathi --- vllm/config.py | 7 ++++--- vllm/platforms/interface.py | 11 ++++++++++- vllm/platforms/tpu.py | 4 ++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ba20e3fd7512..49ff23c3bb3d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) from vllm.model_executor.models import ModelRegistry -from vllm.platforms import CpuArchEnum +from vllm.platforms import CpuArchEnum, current_platform from vllm.sampling_params import GuidedDecodingParams from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( @@ -2382,8 +2382,8 @@ class LoRAConfig: max_cpu_loras: Optional[int] = None lora_dtype: Optional[Union[torch.dtype, str]] = None lora_extra_vocab_size: int = 256 - # This is a constant. - lora_vocab_padding_size: ClassVar[int] = 256 + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() long_lora_scaling_factors: Optional[tuple[float]] = None bias_enabled: bool = False @@ -2405,6 +2405,7 @@ def compute_hash(self) -> str: factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) factors.append(self.long_lora_scaling_factors) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3816d29e6835..1e78bf6196b6 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -331,9 +331,18 @@ def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: @classmethod def can_update_inplace(cls) -> bool: - """Checks if the platform allows inplace memory updates""" + """ + Checks if the platform allows inplace memory updates + """ return True + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + """ + Returns how much padding the LoRA logits need for kernels + """ + return 256 + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 633afa22c5c1..2b7d4c45761a 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -74,6 +74,10 @@ def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: def can_update_inplace(cls): return False + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 0 + @classmethod def inference_mode(cls): return torch.no_grad() From 54db22dbe56242f1131b644571a01fe74e943136 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 14:02:59 +0000 Subject: [PATCH 154/186] Fixed 0 padding issue with LoRA Signed-off-by: Akshat Tripathi --- 
vllm/model_executor/layers/vocab_parallel_embedding.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 1eb0c8c2ef4e..2a5e9395a35d 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -50,6 +50,8 @@ def embedding(self, layer: torch.nn.Module, def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" + if pad_to == 0: + return vocab_size return ((vocab_size + pad_to - 1) // pad_to) * pad_to From 5232785604ab6e435cca1c09767ccb12e77cb2a2 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 14:06:36 +0000 Subject: [PATCH 155/186] Changed TPU lora_vocab_padding_size to 1 Signed-off-by: Akshat Tripathi --- vllm/model_executor/layers/vocab_parallel_embedding.py | 2 -- vllm/platforms/tpu.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 2a5e9395a35d..1eb0c8c2ef4e 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -50,8 +50,6 @@ def embedding(self, layer: torch.nn.Module, def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" - if pad_to == 0: - return vocab_size return ((vocab_size + pad_to - 1) // pad_to) * pad_to diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2b7d4c45761a..d837c982bdb8 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -76,7 +76,7 @@ def can_update_inplace(cls): @classmethod def get_lora_vocab_padding_size(cls) -> int: - return 0 + return 1 @classmethod def inference_mode(cls): From 1b4c2f2b19c04acf5c538d12bdd5feba783b7599 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 16:25:35 +0000 Subject: [PATCH 156/186] Fixed bug in bgmv_expand kernel - outputs weren't being written with just 1 lora Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 846393741fc2..2b54ba7ff20c 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -361,9 +361,9 @@ def _(): lora_ref[i, ...], preferred_element_type=jnp.float32) * mask_ref[...] - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) - def _(): - out_ref[...] = acc_ref[...].astype(out_ref.dtype) + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @functools.partial(jax.jit, From c8f68d7041c8a12873fde5664505ee62d86d87e4 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 14:06:36 +0000 Subject: [PATCH 157/186] Changed TPU lora_vocab_padding_size to 1 Signed-off-by: Akshat Tripathi --- vllm/platforms/tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2b7d4c45761a..d837c982bdb8 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -76,7 +76,7 @@ def can_update_inplace(cls): @classmethod def get_lora_vocab_padding_size(cls) -> int: - return 0 + return 1 @classmethod def inference_mode(cls): From ed3b245134dfc184932f11735ad0c08c0700ae55 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 4 Apr 2025 22:10:36 +0000 Subject: [PATCH 158/186] Enabled lora bias Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 87 +++++++++++++++++--------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 6b2ccf9148ab..397fc90e905a 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple, Union import torch +import torch.nn.functional as F from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink @@ -33,9 +34,13 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, dtype=torch.int32) def mark_compiled(self): + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: + return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) + @property def embeddings_indices(self) -> torch.Tensor: """ @@ -60,33 +65,20 @@ def shrink( ): if self.no_lora: return y - return bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), + scale) - def expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_inputs: bool, - ): - if self.no_lora: - return y - return bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) - def expand_slice( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - y_total_size: int, - add_inputs: bool, - ) -> torch.Tensor: - if self.no_lora: - return y - return bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, - y_offset, y_slice_size, add_inputs) + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -152,9 +144,10 @@ def add_expand(self, y_org = y y = y.view(-1, y.shape[-1]) offset_left = 0 + if lora_bias_stacked is not None: - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias(self._get_token_lora_indices(y), y, + 
output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): y = self.expand_slice( y, @@ -227,8 +220,8 @@ def add_lora_linear(self, assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self.token_lora_indices, y, output_slices, - lora_bias_stacked) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -296,6 +289,42 @@ def add_lora_logits(self, add_inputs=True) return y.view_as(y_org) + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias = torch.where(indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view_as(org_output) + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( From 9d35414ca0d7ee07c6a0740429e56d7778cb2731 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 7 Apr 2025 13:46:02 +0000 Subject: [PATCH 159/186] Replaced `enable_laning` flag with dim comparison Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 18 +++++-------- vllm/lora/ops/xla_ops/lora_ops.py | 12 +++------ vllm/lora/ops/xla_ops/pallas.py | 11 ++++---- vllm/lora/punica_wrapper/punica_tpu.py | 36 +++++++------------------- 4 files changed, 25 insertions(+), 52 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 5d9b64386db2..f50922d4d202 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -5,18 +5,12 @@ # Required to register the custom ops import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import -# N_TOKENS = [ -# 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, -# 131072 -# ] -# HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] +N_TOKENS = [8, 16, 64, 2048] +HIDDEN_SIZES = [2048] -# DTYPES = [torch.float16, torch.bfloat16] -# NUM_LORA = [1, 2, 4, 8, 16, 32] -# RANKS = [8, 16, 32, 64, 128] - -N_TOKENS = [2048] -HIDDEN_SIZES = [4096] +DTYPES = [torch.bfloat16] +NUM_LORA = [1, 2, 4] +RANKS = [32, 256] DTYPES = [torch.bfloat16] NUM_LORA = [1, 2, 4] @@ -101,7 +95,7 @@ def test_lora_laning_correctness(T, D, L, N, dtype, seed): r2 = ref_bgmv(r1, loras_b, idxs) o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs) - o2 = torch.ops.xla.bgmv_expand(o1, loras_b.transpose(2, 3), idxs, True) + o2 = torch.ops.xla.bgmv_expand(o1, loras_b.transpose(2, 3), idxs) # Compare with reference output assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py 
b/vllm/lora/ops/xla_ops/lora_ops.py index c89f4b137a12..96ca63e73d01 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -11,12 +11,10 @@ def bgmv_expand(inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, - *, - enable_laning: bool = False): + add_inputs: bool = True): outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), - lora_indices_tensor, enable_laning) + lora_indices_tensor) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: @@ -47,11 +45,9 @@ def bgmv_expand_slice(inputs: torch.Tensor, lora_indices_tensor: torch.Tensor, slice_offset: int, slice_size: int, - add_inputs: bool = True, - *, - enable_laning: bool = False): + add_inputs: bool = True): outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), - lora_indices_tensor, enable_laning) + lora_indices_tensor) outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - (slice_offset + slice_size), 0, 0)) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 2b54ba7ff20c..97f55187b851 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -22,8 +22,7 @@ # be the outputs of a LoRA laned bgmv_shrink. This is not always the case when # we use bgmv_expand XLA_LIB.define( - "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs, bool enable_laning) \ - -> Tensor") + "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") """ LoRA Laning Optimization for TPU Matrix Multiplication @@ -412,13 +411,13 @@ def bgmv_expand_shape_function(idxs, inputs, loras): @impl(XLA_LIB, "bgmv_expand", "XLA") def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor, enable_laning: bool): + idxs: torch.IntTensor): inputs = inputs.to(dtype=loras.dtype) if len(loras.shape) == 4: loras = loras.squeeze(axis=1) - T, _ = inputs.shape + T, DI = inputs.shape N, D, L = loras.shape TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) @@ -428,7 +427,7 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, # See if we can fit multiple LoRAs in a register. 
This would activate LoRA # laning N_LORA_LANES = math.ceil(DIM_BLOCK / D) - if enable_laning and N_LORA_LANES > 1 and N > 1: + if D != DI and N_LORA_LANES > 1 and N > 1: pad_N = next_multiple_of(N, N_LORA_LANES) - N new_N = N + pad_N @@ -471,7 +470,7 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, @impl(XLA_LIB, "bgmv_expand", "CompositeExplicitAutograd") def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor, enable_laning: bool): + idxs: torch.IntTensor): T, _ = inputs.shape if len(loras.shape) == 4: diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index c3077ff68b0a..77c235a83b39 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -71,25 +71,16 @@ def shrink( return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - add_inputs: bool, enable_laning: bool): - return bgmv_expand(x, - w_t_all, - y, - self._get_token_lora_indices(x), - add_inputs, - enable_laning=enable_laning) + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) def expand_slice(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_inputs: bool, enable_laning: bool) -> torch.Tensor: - return bgmv_expand_slice(x, - w_t_all, - y, - self._get_token_lora_indices(x), - y_offset, - y_slice_size, - add_inputs, - enable_laning=enable_laning) + add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], @@ -157,8 +148,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], - add_inputs=add_inputs, - enable_laning=kwargs["enable_laning"]) + add_inputs=add_inputs) offset_left += output_slices[slice_idx] return y.view_as(y_org) @@ -182,11 +172,7 @@ def add_lora_embedding(self, """ # Embedding layer only needs the expand op - return self.expand(y, - x, - lora_b_stacked, - add_inputs, - enable_laning=False) + return self.expand(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -245,7 +231,6 @@ def add_lora_linear(self, None, output_slices, add_inputs=True, - enable_laning=True, **kwargs) def add_lora_logits(self, @@ -281,8 +266,7 @@ def add_lora_logits(self, lora_b_stacked, y, self.sampler_indices, - add_inputs=True, - enable_laning=True) + add_inputs=True) return y.view_as(y_org) def _apply_bias( From 54c00c3846cd7c0a0f1c90989839964e01c3c511 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 7 Apr 2025 13:51:55 +0000 Subject: [PATCH 160/186] Enabled fully sharded loras Signed-off-by: Akshat Tripathi --- vllm/lora/fully_sharded_layers.py | 39 +++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 41e1ec94145d..e195f8cf5e8e 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -16,6 +16,7 @@ MergedQKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) +from vllm.platforms import current_platform if TYPE_CHECKING: pass @@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): device=x.device, ) - 
layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + buffers = tensor_model_parallel_all_gather(buffers) - layer.punica_wrapper.add_expand(output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -292,7 +303,11 @@ def apply(self, device=x.device, ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -304,7 +319,7 @@ def apply(self, # NOTE offset are based on the rank. shard_size = self.lora_b_stacked[0].shape[2] offset_start = self.tp_rank * shard_size - self.punica_wrapper.add_expand( + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( output, buffer, self.lora_b_stacked, @@ -313,6 +328,10 @@ def apply(self, offset_start=offset_start, add_input=True, ) + + if not current_platform.can_update_inplace(): + output = lora_output + output = output.view(*out_orig_shape) return output From a4b2e278a975425419bc7f961a94ba2da7a4cc44 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 7 Apr 2025 15:33:57 +0000 Subject: [PATCH 161/186] Removed test benchmarking file Signed-off-by: Akshat Tripathi --- bmark_kernels.py | 113 ----------------------------------------------- 1 file changed, 113 deletions(-) delete mode 100644 bmark_kernels.py diff --git a/bmark_kernels.py b/bmark_kernels.py deleted file mode 100644 index 80a022da8f40..000000000000 --- a/bmark_kernels.py +++ /dev/null @@ -1,113 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import itertools -from functools import partial - -import pytest -import torch -import torch_xla.core.xla_model as xm - - -def create_tensors(T, D, L, N, dtype=torch.bfloat16, device='xla'): - """ - Inputs: (All integers) - T: Total number of tokens - D: Input dim - L: LoRA Dim - N: N LoRAs - - Outputs: - inputs: torch.Tensor - shape (T, D) - loras: torch.Tensor - shape (N, 1, L, D) - idxs: torch.IntTensor - shape (T, ) - all values must be in [0, N) - """ - - inputs = torch.randn((T, D), dtype=dtype, device=device) - loras = torch.randn((N, L, D), dtype=dtype, device=device) - idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device=device) - - return inputs, loras, idxs - - -SEQ_LENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 131072] -HIDDEN_DIM = [256, 1024, 4096, 8192, 14336, 28672] -LORA_RANKS = [8, 16, 32, 64, 128, 128] -N_LORAS = [1, 2, 4, 8] - - -@torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - return torch.einsum("td,tld->tl", inputs, loras[idxs]) - - -@torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def bgmv_shrink(inputs: torch.Tensor, loras: 
torch.Tensor, - idxs: torch.IntTensor): - return torch.ops.xla.bgmv_shrink(inputs, loras, idxs) - - -@torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def bgmv_expand(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor, enable_laning: bool): - return torch.ops.xla.bgmv_expand(inputs, loras, idxs, enable_laning) - - -@torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, - loras_b: torch.Tensor, idxs: torch.IntTensor): - return bgmv_expand(bgmv_shrink(inputs, loras_a, idxs), - loras_b, - idxs, - enable_laning=True) - - -@torch.compile(fullgraph=True, dynamic=False, backend="openxla") -def ref_shrink_and_expand(inputs: torch.Tensor, loras_a: torch.Tensor, - loras_b: torch.Tensor, idxs: torch.IntTensor): - return ref_bgmv(ref_bgmv(inputs, loras_a, idxs), loras_b, idxs) - - -def run_and_wait_torch(func, *args): - out = func(*args) - xm.mark_step() - xm.wait_device_ops() - return out - - -@pytest.mark.parametrize( - "T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) -@pytest.mark.parametrize("func", [bgmv_shrink]) -def test_bmark_shrink(benchmark, T, D, L, N, func): - inputs, loras, idxs = create_tensors(T, D, L, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), - args=(inputs, loras, idxs), - rounds=5, - warmup_rounds=5, - iterations=10) - - -@pytest.mark.parametrize( - "T,D,L,N", itertools.product(SEQ_LENS, LORA_RANKS, HIDDEN_DIM, N_LORAS)) -@pytest.mark.parametrize("func", [bgmv_expand]) -def test_bmark_expand(benchmark, T, D, L, N, func): - inputs, loras, idxs = create_tensors(T, D, L, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), - args=(inputs, loras, idxs), - rounds=5, - warmup_rounds=5, - iterations=10) - - -@pytest.mark.parametrize( - "T,D,L,N", itertools.product(SEQ_LENS, HIDDEN_DIM, LORA_RANKS, N_LORAS)) -@pytest.mark.parametrize("func", [shrink_and_expand]) -def test_bmark_shrink_and_expand(benchmark, T, D, L, N, func): - inputs, loras_a, idxs = create_tensors(T, D, L, N) - _, loras_b, _ = create_tensors(T, L, D, N) - - benchmark.pedantic(partial(run_and_wait_torch, func), - args=(inputs, loras_a, loras_b, idxs), - rounds=5, - warmup_rounds=5, - iterations=10) From fbddd3ce4313e88fc8b6329ae1c0c91d60c9ab6a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 8 Apr 2025 16:20:52 +0000 Subject: [PATCH 162/186] Refactored add_shrink to return a tensor not a tuple Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/__init__.py | 5 ++++- vllm/lora/ops/xla_ops/pallas.py | 8 +++++--- vllm/lora/punica_wrapper/punica_tpu.py | 15 +++++++-------- vllm/v1/worker/tpu_model_runner.py | 7 +++---- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py index 94062b05d916..2b6337d8fd8f 100644 --- a/vllm/lora/ops/xla_ops/__init__.py +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -2,5 +2,8 @@ from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink) +from vllm.lora.ops.xla_ops.pallas import LORA_RANK_BLOCK_SIZE -__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] +__all__ = [ + "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "LORA_RANK_BLOCK_SIZE" +] diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 97f55187b851..145c1d363774 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -145,6 +145,8 @@ """ +LORA_RANK_BLOCK_SIZE = 256 + def _bgmv_shrink_kernel(bT: 
int, bL: int, n_lora_lanes: int, lane_size: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, @@ -257,7 +259,7 @@ def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, N, L, D = loras.shape TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) - LORA_BLOCK = 256 + LORA_BLOCK = LORA_RANK_BLOCK_SIZE DIM_BLOCK = largest_divisor(D, [256, 512, 1024]) # See if we can fit multiple LoRAs in a register. This would activate LoRA @@ -314,7 +316,7 @@ def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, N, L, _ = loras.shape - LORA_BLOCK = 256 + LORA_BLOCK = LORA_RANK_BLOCK_SIZE N_LORA_LANES = math.ceil(LORA_BLOCK / L) if N_LORA_LANES > 1 and N > 1: L = LORA_BLOCK @@ -422,7 +424,7 @@ def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) LORA_BLOCK = largest_divisor(L, [256, 512, 1024]) - DIM_BLOCK = 256 + DIM_BLOCK = LORA_RANK_BLOCK_SIZE # See if we can fit multiple LoRAs in a register. This would activate LoRA # laning diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 77c235a83b39..f6059072ca12 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -7,7 +7,8 @@ import torch.nn.functional as F import torch_xla.core.xla_model as xm -from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink +from vllm.lora.ops.xla_ops import (LORA_RANK_BLOCK_SIZE, bgmv_expand, + bgmv_expand_slice, bgmv_shrink) from vllm.lora.punica_wrapper.utils import convert_mapping if TYPE_CHECKING: @@ -98,14 +99,14 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ + torch.ops.xla.dynamo_set_buffer_donor_(y, True) x = x.view(-1, x.shape[-1]) - new_y = [] for slice_idx in range(len(lora_a_stacked)): lora_s = lora_a_stacked[slice_idx] y_s = self.shrink(x, lora_s, scale) - new_y.append(y_s) - return tuple(new_y) + y[slice_idx, :, :] = y_s + return y def add_expand(self, y: torch.Tensor, @@ -215,13 +216,11 @@ def add_lora_linear(self, output_slices, lora_bias_stacked) if buffer is None: - r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default, consistent with the - # triton op + r = max(lora_b_stacked[0].size(-1), LORA_RANK_BLOCK_SIZE) T = x.size(0) buffer = torch.zeros( (len(output_slices), T, r), - dtype=torch.float32, + dtype=x.dtype, device=x.device, ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 76f5e27ad7a4..065acc437803 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -18,6 +18,7 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.ops.xla_ops import LORA_RANK_BLOCK_SIZE from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.utils import group_mm_inputs_by_modality @@ -1084,15 +1085,13 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: def _get_padded_lora_rank(max_lora_rank: int, max_num_loras: int) -> int: - LORA_BLOCK_SIZE = 256 # Same as in the pallas kernel - max_num_loras += 1 # If we have enough LoRAs to use laning without padding - if max_lora_rank * max_num_loras >= LORA_BLOCK_SIZE: + if max_lora_rank 
* max_num_loras >= LORA_RANK_BLOCK_SIZE: return max_lora_rank - return 1 << (LORA_BLOCK_SIZE // max_num_loras).bit_length() + return 1 << (LORA_RANK_BLOCK_SIZE // max_num_loras).bit_length() def _create_dummy_scheduled_tokens(total_tokens: int, From 1803135007e2caa167381296fcde7039005d6aed Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 9 Apr 2025 11:23:14 +0000 Subject: [PATCH 163/186] Removed tuple return in add_shrink() Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 397fc90e905a..0b6867983db9 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -44,7 +44,7 @@ def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: @property def embeddings_indices(self) -> torch.Tensor: """ - This property provides access to the indices used for lora embeddings, + This property provides access to the indices used for lora embeddings, specifically for VocabParallelEmbeddingWithLoRA. """ return self._embeddings_indices[:] @@ -97,21 +97,16 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], scale (float): Scaling factor for the operation """ + torch.ops.xla.dynamo_set_buffer_donor_(y, True) x = x.view(-1, x.shape[-1]) - new_y = [] - # TODO fuse these kernels for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - - y_org = y_s - y_s = y_s.view(-1, y_s.shape[-1]) - y_s = self.shrink(y_s, x, lora_s, scale) - y_s = y_s.view_as(y_org) - new_y.append(y_s) - return tuple(new_y) + y[slice_idx, :, :] = y_s = y_s + + return y def add_expand(self, y: torch.Tensor, From 0eeb72c52669c46924c6c77d0cc7352a07ee7ea8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 9 Apr 2025 21:26:53 +0000 Subject: [PATCH 164/186] Removed extra compilation Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c2b319a51854..4f5aef2d3b61 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -496,17 +496,6 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device) seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += \ - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - - self.set_active_loras(self.input_batch, - padded_num_scheduled_tokens_per_req) - attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -527,6 +516,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Padded to avoid recompiling when `num_reqs` varies. 
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) + + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) return attn_metadata, logits_indices def _scatter_placeholders( @@ -891,6 +891,7 @@ def _dummy_run(self, num_tokens: int) -> None: def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, lora_requests) -> None: + xm.mark_step() # Captures input updates super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, lora_requests) xm.mark_step() # Captures metadata updates From c1be9fea4ec122eb45b7e13ba7f7454d2b28f7aa Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 9 Apr 2025 21:27:45 +0000 Subject: [PATCH 165/186] Replaced copies with buffer donation to reduce memory usage Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 36 +++++++++++++++++--------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index f6059072ca12..9abb6228b339 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -40,6 +40,15 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) + torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, + True) + torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, + True) + def mark_compiled(self): torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) @@ -340,18 +349,20 @@ def _update_base_metadata( "cpu", long_lora_context, ) - self._token_lora_indices[:base_indices.shape[0]].copy_( - base_indices.to(self.device)) - self._sampler_indices[:sampler_indices.shape[0]].copy_( - sampler_indices.to(self.device)) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded.to(self.device)) + self._token_lora_indices[:base_indices.shape[0]] = base_indices.to( + self.device) + self._sampler_indices[:sampler_indices.shape[0]] = sampler_indices.to( + self.device) + self._sampler_indices_padded[:sampler_indices_padded. + shape[0]] = sampler_indices_padded.to( + self.device) self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices.to(self.device)) + shape[0], :embeddings_indices. + shape[1]] = embeddings_indices.to(self.device) if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor.to(self.device)) + self._long_lora_indices[:long_lora_offsets_tensor. 
+ shape[0]] = long_lora_offsets_tensor.to( + self.device) else: self._long_lora_indices.zero_() self.indices_len[:] = indices_len @@ -359,8 +370,9 @@ def _update_base_metadata( def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self.batch_size].copy_( - token_lora_tensor[:self.batch_size]) + self._lora_indices_per_batch[:self. + batch_size] = token_lora_tensor[:self. + batch_size] def _pad_prompt_mapping( self, prompt_mapping: Tuple[int, ...]) -> Tuple[int, ...]: From de5da33ab043a9fb42fc8153ebebf5e916607401 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 10 Apr 2025 11:34:49 +0000 Subject: [PATCH 166/186] Added explicit compilation in add_lora Signed-off-by: Akshat Tripathi --- vllm/v1/worker/tpu_model_runner.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4f5aef2d3b61..a7e234315b2c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -20,6 +20,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.ops.xla_ops import LORA_RANK_BLOCK_SIZE +from vllm.lora.request import LoRARequest from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -1037,6 +1038,43 @@ def sample( sample(logits, sampling_metadata).sampled_token_ids) return out_tokens + def add_lora(self, lora_request: LoRARequest) -> bool: + success = super().add_lora(lora_request) + if not success: + return False + + # Only compile when we see a new LoRA adapter + logger.info("Compiling LoRA adapter %s", lora_request.path) + start = time.perf_counter() + xm.mark_step() + + for n in range(self.lora_config.max_loras): + logger.info(" --lora_index %d", n) + # Create n dummy LoRAs as padding + lora_requests: set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, n + 1) + } + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. 
+ for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + lora_requests.add(lora_request) + + self.lora_manager._apply_adapters(lora_requests) + self.lora_manager.remove_all_adapters() + + xm.wait_device_ops() + end = time.perf_counter() + logger.info("Compilation finished in in %.2f [secs].", end - start) + + return True + def _get_padded_number(n: int, multiple: int) -> int: return ((n + multiple - 1) // multiple) * multiple From 5adc67f5566787afc56aa5ef97908c0463682571 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 10 Apr 2025 11:48:16 +0000 Subject: [PATCH 167/186] Removed LoRA ID collision Signed-off-by: Akshat Tripathi --- vllm/v1/worker/lora_model_runner_mixin.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 5fcc70695492..c31d0f5a8b40 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -96,10 +96,12 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, num_reqs = len(num_scheduled_tokens) num_loras = lora_config.max_loras + base_lora_id = lora_config.max_loras + lora_config.max_cpu_loras + 1 + # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % - num_loras) + 1 + num_loras) + base_lora_id # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, @@ -108,9 +110,9 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, # Make dummy lora requests lora_requests: set[LoRARequest] = { LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, + lora_int_id=lora_id + base_lora_id, lora_path="/not/a/real/path") - for lora_id in range(1, num_loras + 1) + for lora_id in range(num_loras) } with self.lora_manager.dummy_lora_cache(): From 342ff8b4faf9cfcb909fbf6c919725a522b6d0a6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 10 Apr 2025 13:13:28 +0000 Subject: [PATCH 168/186] Fix pre-commit Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 0b6867983db9..310f8c66cc5b 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -103,9 +103,8 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - y_s = self.shrink(y_s, x, lora_s, scale) - y[slice_idx, :, :] = y_s = y_s - + y_s = self.shrink(x, lora_s, scale) + y[slice_idx, :, :] = y_s # type: ignore[index] return y def add_expand(self, From fc65edb058edcfe340e56ac7e1f5a614f215d369 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 10 Apr 2025 13:13:56 +0000 Subject: [PATCH 169/186] Reduced number of iterations in test_lora Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index 2fafd9b1fc2d..70cf1dc25146 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -22,7 +22,7 @@ def test_lora_hotswapping(): prompt = "What is 1+1? 
\n" - for _ in range(10): + for _ in range(2): for i, req in enumerate(lora_requests): output = llm.generate(prompt, sampling_params=vllm.SamplingParams( From 7daaafae466e5e8b226a34128fb9807d0315a330 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 10 Apr 2025 13:34:45 +0000 Subject: [PATCH 170/186] Lint Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_tpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 310f8c66cc5b..19a01140df4d 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -103,7 +103,7 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], for slice_idx in range(len(lora_a_stacked)): y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - y_s = self.shrink(x, lora_s, scale) + y_s = self.shrink(y_s, x, lora_s, scale) y[slice_idx, :, :] = y_s # type: ignore[index] return y From 893ac041ef785c9dc53fcb735e3d00d400e70cda Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 11 Apr 2025 11:42:37 +0000 Subject: [PATCH 171/186] Reduced pallas kernel test size Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index 3490246d5991..b36b3c8f709c 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -5,15 +5,12 @@ # Required to register the custom ops import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import -N_TOKENS = [ - 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, - 131072 -] -HIDDEN_SIZES = [128, 256, 512, 896, 1024, 2048, 4096, 8192, 8320] - -DTYPES = [torch.float16, torch.bfloat16] -NUM_LORA = [1, 2, 4, 8, 16, 32] -RANKS = [8, 16, 32, 64, 128] +N_TOKENS = [16, 1024, 4096] +HIDDEN_SIZES = [1024, 2048, 4096] + +DTYPES = [torch.float32, torch.bfloat16] +NUM_LORA = [1, 4, 16] +RANKS = [32, 256, 512] def generate_test_data(T, D, L, N, seed, dtype=torch.float32): From 2a0fce733f448ba8aa1b843d06f2c3775fa4a7c0 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 11 Apr 2025 11:43:41 +0000 Subject: [PATCH 172/186] Added/removed comments Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/lora_ops.py | 41 ++++++++++++++++++++++++++ vllm/lora/punica_wrapper/punica_tpu.py | 5 ++-- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index f1b642e5852c..acbec0cfab9c 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -11,6 +11,22 @@ def bgmv_expand(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. 
+ """ + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) n_tokens = outputs.size(0) @@ -35,6 +51,16 @@ def bgmv_shrink(inputs: torch.Tensor, output_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, scaling: float = 1.0): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + output_tensor (torch.Tensor): (Unused) output tensor (placeholder). + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + scaling (float, optional): Scalar multiplier applied to the output. + """ return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) @@ -47,6 +73,21 @@ def bgmv_expand_slice(inputs: torch.Tensor, slice_offset: int, slice_size: int, add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) n_tokens = outputs.size(0) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 19a01140df4d..49eddaa7876c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -10,8 +10,6 @@ from .punica_base import PunicaWrapperBase -# The platforms that are compatible with the PyTorch-native implementation can -# inherit this class class PunicaWrapperTPU(PunicaWrapperBase): """ PunicaWrapperTPU is designed to manage and provide metadata for the punica @@ -319,7 +317,8 @@ def _apply_bias( return output.view_as(org_output) - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 self._lora_indices_per_batch[:self.batch_size].copy_( token_lora_tensor[:self.batch_size]) From 4d4284449ead8e5b04c5f904b08b810d6af96705 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 11 Apr 2025 15:05:37 +0000 Subject: [PATCH 173/186] Fixed pallas kernel test Signed-off-by: Akshat Tripathi --- tests/lora/tpu/test_pallas_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py index b36b3c8f709c..6430713158bd 100644 --- a/tests/lora/tpu/test_pallas_kernels.py +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -8,7 +8,7 @@ N_TOKENS = [16, 1024, 4096] HIDDEN_SIZES = [1024, 2048, 4096] -DTYPES = [torch.float32, torch.bfloat16] +DTYPES = [torch.float16] NUM_LORA = [1, 4, 16] RANKS = [32, 256, 512] @@ -70,4 +70,4 @@ def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): assert not torch.any(torch.isnan(output)) # Compare with reference output - assert torch.allclose(output, ref_output, rtol=1e-3, atol=1e-3) + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) From 50a06fcfe00d3426d777534006a471d3591ebc8f Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 11 Apr 2025 15:20:56 +0000 Subject: [PATCH 174/186] Made LoRA e2e test more robust 
Signed-off-by: Akshat Tripathi --- tests/tpu/test_lora.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py index 70cf1dc25146..20b7169910a4 100644 --- a/tests/tpu/test_lora.py +++ b/tests/tpu/test_lora.py @@ -1,9 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + import vllm from vllm.lora.request import LoRARequest -def test_lora_hotswapping(): +@pytest.fixture(scope="function", autouse=True) +def use_v1_only(monkeypatch: pytest.MonkeyPatch): + """ + Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1 + for all tests in this file + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + yield + + +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lora_e2e(num_loras: int): + """ + This test ensures that we can run with LoRA adapters on the TPU backend. + It verifies multiple capabilities: + 1. We can compile a model with LoRA adapters enabled + 2. We can run LoRA adapters + 3. We receive correct outputs when running with multiple LoRA adapters + 4. We can swap LoRA adapters between host and device + """ lora_name_template = \ "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" lora_requests = [ @@ -17,7 +39,7 @@ def test_lora_hotswapping(): max_seq_len_to_capture=256, max_num_seqs=8, enable_lora=True, - max_loras=2, + max_loras=num_loras, max_lora_rank=8) prompt = "What is 1+1? \n" From 177fcede4ccd8c2cd7a8a2be9fe0f41dbbbf4efe Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 22 Apr 2025 10:20:36 +0000 Subject: [PATCH 175/186] Added mark_steps to set_lora to break up large graphs Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 8b45d8b7b647..2d444b7a991f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch_xla.core.xla_model as xm from transformers import PretrainedConfig from vllm.adapter_commons.layers import AdapterMapping @@ -213,6 +214,7 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): + xm.mark_step() self.reset_lora(index) self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( lora_a, non_blocking=True) @@ -235,6 +237,7 @@ def set_lora( )[self.embeddings_slice[0]:self.embeddings_slice[1]] assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + xm.mark_step() def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, @@ -381,6 +384,7 @@ def set_lora( # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. 
+ xm.mark_step() assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1) @@ -404,6 +408,7 @@ def set_lora( assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( lora_bias.T, non_blocking=True) + xm.mark_step() def apply(self, x: torch.Tensor, @@ -706,6 +711,7 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], lora_bias: Optional[torch.Tensor] = None, ): + xm.mark_step() self.reset_lora(index) if self.tp_size > 1: @@ -733,6 +739,7 @@ def set_lora( 0, :lora_bias_i.shape[0]].copy_( lora_bias_i.T, non_blocking=True) + xm.mark_step() @classmethod @_not_fully_sharded_can_replace @@ -1078,6 +1085,7 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): + xm.mark_step() self.reset_lora(index) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -1091,6 +1099,7 @@ def set_lora( :embeddings_tensor.shape[0], :embeddings_tensor.shape[1], ] = embeddings_tensor + xm.mark_step() def _get_logits( self, From 12ff3642542bff8c092c78adf95ef58e2293d62a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 22 Apr 2025 10:21:11 +0000 Subject: [PATCH 176/186] Stopped index based recompilation for multi-lora Signed-off-by: Akshat Tripathi --- vllm/lora/layers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 2d444b7a991f..9bcb95e68930 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -215,6 +215,7 @@ def set_lora( bias: Optional[torch.Tensor] = None, ): xm.mark_step() + index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( lora_a, non_blocking=True) @@ -385,6 +386,7 @@ def set_lora( # store weights in a tuple of size 1. These two layers will # override this function. 
xm.mark_step() + index = torch.tensor([index], dtype=torch.int32, device="xla") assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1) @@ -712,6 +714,7 @@ def set_lora( lora_bias: Optional[torch.Tensor] = None, ): xm.mark_step() + index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) if self.tp_size > 1: @@ -1086,6 +1089,7 @@ def set_lora( bias: Optional[torch.Tensor] = None, ): xm.mark_step() + index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( From 491578df08ddb3701f8cb7bed41f553a2afcf769 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 22 Apr 2025 10:22:41 +0000 Subject: [PATCH 177/186] Restored original maybe_dummy_run_with_lora Signed-off-by: Akshat Tripathi --- vllm/v1/worker/lora_model_runner_mixin.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index c31d0f5a8b40..a8a19e0e6206 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -14,7 +14,6 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.platforms import current_platform from vllm.v1.worker.gpu_input_batch import InputBatch logger = init_logger(__name__) @@ -96,12 +95,10 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, num_reqs = len(num_scheduled_tokens) num_loras = lora_config.max_loras - base_lora_id = lora_config.max_loras + lora_config.max_cpu_loras + 1 - # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) % - num_loras) + base_lora_id + num_loras) + 1 # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, @@ -110,9 +107,9 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, # Make dummy lora requests lora_requests: set[LoRARequest] = { LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id + base_lora_id, + lora_int_id=lora_id, lora_path="/not/a/real/path") - for lora_id in range(num_loras) + for lora_id in range(1, num_loras + 1) } with self.lora_manager.dummy_lora_cache(): @@ -129,10 +126,7 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, yield # __exit__ code - # Disabling remove_all_adapters on the TPU backend allows us to save - # quite a bit of RAM. E.g. 
we save 2.22 GB with Llama3.1 8B - if not current_platform.is_tpu(): - self.lora_manager.remove_all_adapters() + self.lora_manager.remove_all_adapters() def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: From 5cb4724f4116137b736205b77ca08547efc47280 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Tue, 22 Apr 2025 11:56:21 +0000 Subject: [PATCH 178/186] Split up into lora setup and lora selection functions Signed-off-by: Akshat Tripathi --- vllm/v1/worker/lora_model_runner_mixin.py | 57 +++++++++++++++++------ vllm/v1/worker/tpu_model_runner.py | 11 +++-- vllm/v1/worker/tpu_worker.py | 5 +- 3 files changed, 51 insertions(+), 22 deletions(-) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a8a19e0e6206..35af83e9cc80 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -84,8 +84,38 @@ def set_active_loras(self, input_batch: InputBatch, lora_requests) @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): + def maybe_setup_dummy_loras(self, lora_config): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_loras = lora_config.max_loras + + # Make dummy lora requests + lora_requests: set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + @contextmanager + def maybe_select_dummy_loras(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): if lora_config is None: yield else: @@ -112,21 +142,18 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, for lora_id in range(1, num_loras + 1) } - with self.lora_manager.dummy_lora_cache(): - # Add the dummy LoRAs here so _set_active_loras doesn't try to - # load from disk. 
- for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) - - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), - lora_requests) + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), lora_requests) - yield + yield - # __exit__ code - self.lora_manager.remove_all_adapters() + @contextmanager + def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + with self.maybe_setup_dummy_loras( + lora_config), self.maybe_select_dummy_loras( + lora_config, num_scheduled_tokens): + yield def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 052f5f77f4d4..26ab544bc568 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -889,7 +889,7 @@ def _dummy_run(self, num_tokens: int) -> None: torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - with self.maybe_dummy_run_with_lora( + with self.maybe_select_dummy_loras( self.lora_config, np.array([num_tokens], dtype=np.int32)), set_forward_context( attn_metadata, self.vllm_config, 0): @@ -961,7 +961,7 @@ def _precompile_sample_from_hidden(self) -> None: generate_params_if_all_greedy, )) sampling_metadata.all_greedy = all_greedy - with self.maybe_dummy_run_with_lora( + with self.maybe_select_dummy_loras( self.lora_config, np.array([num_reqs], dtype=np.int32)): self.sample_from_hidden(dummy_hidden, sampling_metadata) @@ -976,9 +976,10 @@ def capture_model(self) -> None: Precompile all the subgraphs with possible input shapes. """ # TODO: precompile encoder - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_sample_from_hidden() + with self.maybe_setup_dummy_loras(self.lora_config): + self._precompile_backbone() + self._precompile_select_hidden_states() + self._precompile_sample_from_hidden() def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index eabe6c623852..99ceb61bad60 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -157,8 +157,9 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - self.model_runner._dummy_run( - self.scheduler_config.max_num_batched_tokens) + with self.model_runner.maybe_setup_dummy_loras(self.lora_config): + self.model_runner._dummy_run( + self.scheduler_config.max_num_batched_tokens) # Synchronize before measuring the memory usage. 
xm.wait_device_ops() From cba82678ddf73a775cba399d58c260602958a910 Mon Sep 17 00:00:00 2001 From: xihajun Date: Tue, 22 Apr 2025 13:54:46 +0000 Subject: [PATCH 179/186] refactor mark step from layers into tpu model runner Signed-off-by: xihajun --- vllm/lora/layers.py | 12 ------------ vllm/v1/worker/tpu_model_runner.py | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9bcb95e68930..7a2c143bba1c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -214,8 +214,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( lora_a, non_blocking=True) @@ -238,7 +236,6 @@ def set_lora( )[self.embeddings_slice[0]:self.embeddings_slice[1]] assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - xm.mark_step() def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, @@ -385,8 +382,6 @@ def set_lora( # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1) @@ -410,7 +405,6 @@ def set_lora( assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( lora_bias.T, non_blocking=True) - xm.mark_step() def apply(self, x: torch.Tensor, @@ -713,8 +707,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], lora_bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) if self.tp_size > 1: @@ -742,7 +734,6 @@ def set_lora( 0, :lora_bias_i.shape[0]].copy_( lora_bias_i.T, non_blocking=True) - xm.mark_step() @classmethod @_not_fully_sharded_can_replace @@ -1088,8 +1079,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -1103,7 +1092,6 @@ def set_lora( :embeddings_tensor.shape[0], :embeddings_tensor.shape[1], ] = embeddings_tensor - xm.mark_step() def _get_logits( self, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 26ab544bc568..c9e687cea2d9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -43,6 +43,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm.lora.layers import BaseLayerWithLoRA if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -827,6 +828,7 @@ def load_model(self) -> None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) + replace_set_lora(model) punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper if not self.enforce_eager: punica_wrapper.mark_compiled() @@ -1182,3 +1184,28 @@ def _create_dummy_scheduled_tokens(total_tokens: int, tokens[-1] += leftover_tokens return tokens + +def replace_set_lora(model): + def 
_tpu_set_lora( + self, + idx: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None + ): + xm.mark_step() + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) + xm.mark_step() + + def _tpu_reset_lora(self, idx: int): + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_reset_lora(index) + + for _, module in model.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + module._original_set_lora = module.set_lora + module._original_reset_lora = module.reset_lora + module.set_lora = _tpu_set_lora.__get__(module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) \ No newline at end of file From d4b370765a6f824c77e0d1fdc3a5441999276b14 Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Tue, 22 Apr 2025 13:54:46 +0000 Subject: [PATCH 180/186] refactor mark step from layers into tpu model runner Signed-off-by: Jorge de Freitas --- vllm/lora/layers.py | 12 ------------ vllm/v1/worker/tpu_model_runner.py | 27 +++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 9bcb95e68930..7a2c143bba1c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -214,8 +214,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( lora_a, non_blocking=True) @@ -238,7 +236,6 @@ def set_lora( )[self.embeddings_slice[0]:self.embeddings_slice[1]] assert self.embeddings_weights is not None self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) - xm.mark_step() def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, @@ -385,8 +382,6 @@ def set_lora( # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers # store weights in a tuple of size 1. These two layers will # override this function. 
- xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1) @@ -410,7 +405,6 @@ def set_lora( assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( lora_bias.T, non_blocking=True) - xm.mark_step() def apply(self, x: torch.Tensor, @@ -713,8 +707,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], lora_bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) if self.tp_size > 1: @@ -742,7 +734,6 @@ def set_lora( 0, :lora_bias_i.shape[0]].copy_( lora_bias_i.T, non_blocking=True) - xm.mark_step() @classmethod @_not_fully_sharded_can_replace @@ -1088,8 +1079,6 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None, ): - xm.mark_step() - index = torch.tensor([index], dtype=torch.int32, device="xla") self.reset_lora(index) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -1103,7 +1092,6 @@ def set_lora( :embeddings_tensor.shape[0], :embeddings_tensor.shape[1], ] = embeddings_tensor - xm.mark_step() def _get_logits( self, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 26ab544bc568..c9e687cea2d9 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -43,6 +43,7 @@ from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm.lora.layers import BaseLayerWithLoRA if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -827,6 +828,7 @@ def load_model(self) -> None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) + replace_set_lora(model) punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper if not self.enforce_eager: punica_wrapper.mark_compiled() @@ -1182,3 +1184,28 @@ def _create_dummy_scheduled_tokens(total_tokens: int, tokens[-1] += leftover_tokens return tokens + +def replace_set_lora(model): + def _tpu_set_lora( + self, + idx: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None + ): + xm.mark_step() + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) + xm.mark_step() + + def _tpu_reset_lora(self, idx: int): + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_reset_lora(index) + + for _, module in model.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + module._original_set_lora = module.set_lora + module._original_reset_lora = module.reset_lora + module.set_lora = _tpu_set_lora.__get__(module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) \ No newline at end of file From ed1738ad2ee066bbd78c8e90eb5b8afec77ded4c Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Fri, 25 Apr 2025 17:37:40 +0000 Subject: [PATCH 181/186] mask creation moved outside of matmul loop Signed-off-by: Jorge de Freitas --- vllm/lora/ops/xla_ops/pallas.py | 71 +++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 30 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 145c1d363774..2d4c0a2e0c35 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ 
-150,9 +150,23 @@ def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, - out_ref, acc_ref, mask_ref): + out_ref, acc_ref, mask_ref, lanes_ref): + + t = pl.program_id(0) + l = pl.program_id(1) + d = pl.program_id(2) - @pl.when(pl.program_id(2) == 0) + @pl.when((t == 0) & (l == 0) & (d == 0)) + def _(): + lanes_ref[...] = jnp.zeros_like(lanes_ref[...], dtype=jnp.float32) + ones = jnp.ones((lane_size, ), dtype=jnp.float32) + + for i in range(n_lora_lanes): + start = i * lane_size + end = start + lane_size + lanes_ref.at[i,start:end].set(ones) + + @pl.when(d == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) @@ -162,29 +176,17 @@ def _(): (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) else: - t = pl.program_id(0) - - ones = jnp.ones((lane_size, ), dtype=jnp.float32) - - base_lora_idx = 0 - for lane_idx in range(max_num_loras): + def _run_lane_step(lane_idx, base_lora_idx): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - idx = idx_ref[j + bT * t] - for k in range(n_lora_lanes): - lora_idx = base_lora_idx + k - set_mask = idx == lora_idx - valid |= set_mask - - @pl.when(set_mask) - def _(): - lane_start = k * lane_size - lane_end = lane_start + lane_size - - mask_ref.at[j, lane_start:lane_end].set(ones) - base_lora_idx += n_lora_lanes + def _run_token_step(j, valid): + idx = idx_ref[j + bT * t] - base_lora_idx + set_mask = (idx >= 0) & (idx < n_lora_lanes) + @pl.when(set_mask) + def _(): + mask_ref.at[j, :].set(lanes_ref[idx]) + return valid | set_mask + valid = jax.lax.fori_loop(0, bT, _run_token_step, False) @pl.when(valid) def _(): @@ -192,6 +194,10 @@ def _(): inp_ref[...], lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) * mask_ref[...] + + return base_lora_idx + n_lora_lanes + _ = jax.lax.fori_loop(0, max_num_loras, _run_lane_step, 0) + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): @@ -233,7 +239,8 @@ def _bgmv_shrink( lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), - pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((N_LORA_LANES, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), @@ -345,22 +352,26 @@ def _(): ones = jnp.ones((bL, ), dtype=jnp.float32) - for i in range(max_num_loras): + def _run_lane_step(lane_idx, _): mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - valid = False - for j in range(bT): - valid |= idx_ref[j + bT * t] == i - @pl.when(idx_ref[j + bT * t] == i) + def _run_token_step(j, valid): + set_mask = idx_ref[j + bT * t] == lane_idx + @pl.when(set_mask) def _(): mask_ref.at[j, :].set(ones) + return valid | set_mask + valid = jax.lax.fori_loop(0, bT, _run_token_step, False) @pl.when(valid) def _(): acc_ref[...] += jax.lax.dot( inp_ref[...], - lora_ref[i, ...], + lora_ref[lane_idx, ...], preferred_element_type=jnp.float32) * mask_ref[...] 
+ + return + _ = jax.lax.fori_loop(0, max_num_loras, _run_lane_step, None) @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): From f81111ee32b1a7b4b6710c020e129f924cf1d63c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Mon, 28 Apr 2025 10:51:36 +0000 Subject: [PATCH 182/186] Moved mask setup outside of lora running loop Signed-off-by: Akshat Tripathi --- vllm/lora/ops/xla_ops/pallas.py | 94 ++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py index 2d4c0a2e0c35..f73cd29a26b7 100644 --- a/vllm/lora/ops/xla_ops/pallas.py +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -151,12 +151,12 @@ def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, max_num_loras: int, idx_ref, inp_ref, lora_ref, out_ref, acc_ref, mask_ref, lanes_ref): - - t = pl.program_id(0) - l = pl.program_id(1) - d = pl.program_id(2) - @pl.when((t == 0) & (l == 0) & (d == 0)) + t_idx = pl.program_id(0) + l_idx = pl.program_id(1) + d_idx = pl.program_id(2) + + @pl.when((t_idx == 0) & (l_idx == 0) & (d_idx == 0)) def _(): lanes_ref[...] = jnp.zeros_like(lanes_ref[...], dtype=jnp.float32) ones = jnp.ones((lane_size, ), dtype=jnp.float32) @@ -164,9 +164,9 @@ def _(): for i in range(n_lora_lanes): start = i * lane_size end = start + lane_size - lanes_ref.at[i,start:end].set(ones) + lanes_ref.at[i, start:end].set(ones) - @pl.when(d == 0) + @pl.when(d_idx == 0) def _(): acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) @@ -176,30 +176,34 @@ def _(): (((1, ), (1, )), ((), ())), preferred_element_type=jnp.float32) else: - def _run_lane_step(lane_idx, base_lora_idx): - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - - def _run_token_step(j, valid): - idx = idx_ref[j + bT * t] - base_lora_idx - set_mask = (idx >= 0) & (idx < n_lora_lanes) - @pl.when(set_mask) - def _(): - mask_ref.at[j, :].set(lanes_ref[idx]) - return valid | set_mask - valid = jax.lax.fori_loop(0, bT, _run_token_step, False) - - @pl.when(valid) + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + + def _mask_setup_step(i, valid): + idx = idx_ref[i + bT * t_idx] + inner_lane_idx = idx % n_lora_lanes + outer_lane_idx = idx // n_lora_lanes + + mask_ref.at[outer_lane_idx, i, :].set(lanes_ref[inner_lane_idx]) + + return valid | (1 << outer_lane_idx) + + valid = jax.lax.fori_loop(0, bT, _mask_setup_step, 0) + + def _lora_matmul_step(lane_idx, check_bit): + + @pl.when((valid & check_bit) > 0) def _(): acc_ref[...] += jax.lax.dot_general( inp_ref[...], lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] - - return base_lora_idx + n_lora_lanes - _ = jax.lax.fori_loop(0, max_num_loras, _run_lane_step, 0) - + preferred_element_type=jnp.float32) * mask_ref[lane_idx, + ...] - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + return check_bit << 1 + + _ = jax.lax.fori_loop(0, max_num_loras, _lora_matmul_step, 1) + + @pl.when(d_idx == pl.num_programs(2) - 1) def _(): out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) @@ -239,7 +243,7 @@ def _bgmv_shrink( lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), - pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((N, TOKEN_BLOCK, LORA_BLOCK), jnp.float32), pltpu.VMEM((N_LORA_LANES, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( @@ -352,26 +356,30 @@ def _(): ones = jnp.ones((bL, ), dtype=jnp.float32) - def _run_lane_step(lane_idx, _): - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + + def _mask_setup_step(i, valid): + idx = idx_ref[i + bT * t] + lane_idx = idx % max_num_loras - def _run_token_step(j, valid): - set_mask = idx_ref[j + bT * t] == lane_idx - @pl.when(set_mask) - def _(): - mask_ref.at[j, :].set(ones) - return valid | set_mask - valid = jax.lax.fori_loop(0, bT, _run_token_step, False) + mask_ref.at[lane_idx, i, :].set(ones) + return valid | (1 << lane_idx) - @pl.when(valid) + valid = jax.lax.fori_loop(0, bT, _mask_setup_step, 0) + + def _lora_matmul_step(lane_idx, check_bit): + + @pl.when((valid & check_bit) > 0) def _(): acc_ref[...] += jax.lax.dot( inp_ref[...], lora_ref[lane_idx, ...], - preferred_element_type=jnp.float32) * mask_ref[...] - - return - _ = jax.lax.fori_loop(0, max_num_loras, _run_lane_step, None) + preferred_element_type=jnp.float32) * mask_ref[lane_idx, + ...] + + return check_bit << 1 + + _ = jax.lax.fori_loop(0, max_num_loras, _lora_matmul_step, 1) @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) def _(): @@ -408,11 +416,11 @@ def _bgmv_expand( lambda i, j, k, block_idx: (i, j)), scratch_shapes=[ pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), - pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32) + pltpu.VMEM((N, TOKEN_BLOCK, LORA_BLOCK), jnp.float32) ]), compiler_params=pltpu.TPUCompilerParams( dimension_semantics=("parallel", "parallel", "arbitrary")), - name="bgmv_pre_transpose")(idxs, inputs, loras) + name="bgmv_expand")(idxs, inputs, loras) def bgmv_expand_shape_function(idxs, inputs, loras): From 94bc0e20c2ab910587cd4a28866d093d361c2198 Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Tue, 29 Apr 2025 13:47:01 +0000 Subject: [PATCH 183/186] compile reset_lora function as separate graph Signed-off-by: Jorge de Freitas --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c9e687cea2d9..3784979e7711 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1194,7 +1194,6 @@ def _tpu_set_lora( embeddings_tensor: Optional[torch.Tensor], bias: Optional[torch.Tensor] = None ): - xm.mark_step() index = torch.tensor([idx], dtype=torch.int32, device="xla") self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) xm.mark_step() @@ -1202,6 +1201,7 @@ def _tpu_set_lora( def _tpu_reset_lora(self, idx: int): index = torch.tensor([idx], dtype=torch.int32, device="xla") self._original_reset_lora(index) + xm.mark_step() for _, module in model.named_modules(): if isinstance(module, BaseLayerWithLoRA): From b0bfc7a81e0329d6e7b8b5c678a4467cae6891dc Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Wed, 30 Apr 2025 11:22:07 +0000 Subject: [PATCH 184/186] update base metadata padding moved to cpu Signed-off-by: Jorge de Freitas --- vllm/lora/punica_wrapper/punica_tpu.py | 46 +++++++++++++++++--------- 1 file changed, 30 insertions(+), 16 
deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 9b4171f8da0b..3a596625eff9 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -349,25 +349,29 @@ def _update_base_metadata( extra_vocab_size, "cpu", long_lora_context, - ) - self._token_lora_indices[:base_indices.shape[0]] = base_indices.to( - self.device) - self._sampler_indices[:sampler_indices.shape[0]] = sampler_indices.to( - self.device) - self._sampler_indices_padded[:sampler_indices_padded. - shape[0]] = sampler_indices_padded.to( - self.device) - self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices. - shape[1]] = embeddings_indices.to(self.device) + ) + self._token_lora_indices = self._pad_to_shape( + base_indices, self._token_lora_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices = self._pad_to_shape( + sampler_indices, self._sampler_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices_padded = self._pad_to_shape( + sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 + ).to(self.device) + self._embeddings_indices = self._pad_to_shape( + embeddings_indices, self._embeddings_indices.shape, dims=2 + ).to(self.device) if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor. - shape[0]] = long_lora_offsets_tensor.to( - self.device) + self._long_lora_indices = self._pad_to_shape( + long_lora_offsets_tensor, self._long_lora_indices.shape, dims=1 + ).to(self.device) else: - self._long_lora_indices.zero_() + zeroed = torch.zeros_like( + self._long_lora_indices.cpu(), dtype=torch.int32 + ) + self._long_lora_indices = zeroed.to(self.device) self.indices_len[:] = indices_len - xm.mark_step() def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 @@ -388,3 +392,13 @@ def _pad_prompt_mapping( padding = [-1] * pad_len return tuple(list(prompt_mapping) + padding) + + def _pad_to_shape(self, src, target_shape, dims=1): + if dims == 1: + pad_len = target_shape[0] - src.shape[0] + return F.pad(src, (0, pad_len), value=0).to(torch.int32) + else: + pad_rows = target_shape[0] - src.shape[0] + pad_cols = target_shape[1] - src.shape[1] + return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) + From 966e80028a823d7735ed715a20bedb59d5dbcd2b Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Thu, 1 May 2025 16:47:19 +0000 Subject: [PATCH 185/186] fix size of sampler indices Signed-off-by: Jorge de Freitas --- vllm/lora/punica_wrapper/punica_tpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 3a596625eff9..6cc98f683746 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -270,11 +270,12 @@ def add_lora_logits(self, y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - buffer = bgmv_shrink(x, lora_a_stacked, self.sampler_indices, scale) + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) y = bgmv_expand(buffer, lora_b_stacked, y, - self.sampler_indices, + sampler_indices, add_inputs=True) return y.view_as(y_org) From 22bafee9f397da06fda18875d234162927ccefed Mon Sep 17 00:00:00 2001 From: Jorge de Freitas Date: Tue, 6 May 2025 12:54:18 +0000 Subject: [PATCH 186/186] remove add lora Signed-off-by: Jorge de Freitas --- 
vllm/v1/worker/tpu_model_runner.py | 39 +----------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3784979e7711..d427fa32ac30 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1061,44 +1061,7 @@ def get_multimodal_embeddings(self, *args, **kwargs): def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) - - def add_lora(self, lora_request: LoRARequest) -> bool: - success = super().add_lora(lora_request) - if not success: - return False - - # Only compile when we see a new LoRA adapter - logger.info("Compiling LoRA adapter %s", lora_request.path) - start = time.perf_counter() - xm.mark_step() - - for n in range(self.lora_config.max_loras): - logger.info(" --lora_index %d", n) - # Create n dummy LoRAs as padding - lora_requests: set[LoRARequest] = { - LoRARequest(lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_path="/not/a/real/path") - for lora_id in range(1, n + 1) - } - with self.lora_manager.dummy_lora_cache(): - # Add the dummy LoRAs here so _set_active_loras doesn't try to - # load from disk. - for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) - - lora_requests.add(lora_request) - - self.lora_manager._apply_adapters(lora_requests) - self.lora_manager.remove_all_adapters() - - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in in %.2f [secs].", end - start) - - return True - + def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: logger.info("Preparing request paddings:")
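
A note on the masked accumulation used by the reworked Pallas kernels (patches 181-182 above): per block, the kernel builds one {0,1} mask per LoRA slot, records in a bit field which slots are referenced by any token in the block, and only issues a matmul for slots whose bit is set. The sketch below is an illustrative eager-mode reference, not code from the patches; the function name, the float32 upcasts, and the Python-level bit-field bookkeeping are assumptions made for clarity. It can serve as a CPU baseline when sanity-checking kernel outputs, since it reduces to the plain per-token gather-and-matmul.

    import torch

    def ref_masked_bgmv(inputs: torch.Tensor, loras: torch.Tensor,
                        idxs: torch.Tensor) -> torch.Tensor:
        # inputs: [T, in_dim], loras: [N, out_dim, in_dim], idxs: [T] (int)
        T = inputs.shape[0]
        N, out_dim, _ = loras.shape
        acc = torch.zeros(T, out_dim)
        masks = torch.zeros(N, T, out_dim)
        valid = 0
        # Mask setup happens once, outside the matmul loop, mirroring the
        # restructuring in patches 181-182.
        for t in range(T):
            masks[int(idxs[t]), t, :] = 1.0
            valid |= 1 << int(idxs[t])
        # Only LoRA slots referenced by at least one token contribute.
        for n in range(N):
            if valid & (1 << n):
                acc += (inputs.float() @ loras[n].float().T) * masks[n]
        # Mathematically equivalent to the per-token gather + matmul:
        #   torch.einsum("td,tod->to", inputs.float(), loras[idxs].float())
        return acc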