diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index eb82da3a883e..5dd53420dfdf 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -122,10 +122,8 @@ run_and_track_test 11 "test_struct_output_generate.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" run_and_track_test 12 "test_moe_pallas.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" - -# Disable the TPU LoRA tests until the feature is activated -# run_and_track_test 13 "test_lora (directory)" \ -# "python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/" +run_and_track_test 13 "test_lora.py" \ + "VLLM_XLA_CHECK_RECOMPILATION=0 python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/test_lora.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/tests/tpu/lora/test_pallas_kernels.py b/tests/tpu/lora/test_pallas_kernels.py deleted file mode 100644 index 8bd47de50c34..000000000000 --- a/tests/tpu/lora/test_pallas_kernels.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torch - -# Required to register the custom ops -import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import - -N_TOKENS = [16, 1024, 4096] -HIDDEN_SIZES = [1024, 2048, 4096] - -DTYPES = [torch.bfloat16] -NUM_LORA = [1, 4, 16] -RANKS = [32, 256, 512] - - -def generate_test_data(T, D, L, N, seed, dtype=torch.float32): - """ - Inputs: (All integers) - T: Total number of tokens - D: Input dim - L: LoRA Dim - N: N LoRAs - - Outputs: - inputs: torch.Tensor - shape (T, D) - loras: torch.Tensor - shape (N, 1, L, D) - idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) - - ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T - """ - torch.manual_seed(seed) - - inputs = torch.randn((T, D), device="xla", dtype=dtype) - loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) - idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") - - ref_output = ref_bgmv(inputs, loras, idxs) - return inputs, loras, idxs, ref_output - - -def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): - selected_loras = loras[idxs] - if len(selected_loras.shape) == 4: - selected_loras = selected_loras.squeeze(axis=1) - - batch_size, output_size, input_size = selected_loras.shape - return (selected_loras @ inputs.reshape( - (batch_size, input_size, 1))).reshape((batch_size, output_size)) - - -# Parameterize tests with various shapes and dtypes -@pytest.mark.parametrize("T", N_TOKENS) -@pytest.mark.parametrize("D", HIDDEN_SIZES) -@pytest.mark.parametrize("L", RANKS) -@pytest.mark.parametrize("N", NUM_LORA) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", [0]) -def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): - if op_type == "expand": - D, L = L, D - - inputs, loras, idxs, ref_output = generate_test_data( - T, D, L, N, seed, dtype) - - # Run bgmv - output = torch.ops.xla.bgmv(inputs, loras, idxs) - - # Make sure we have no NaNs - assert not torch.any(torch.isnan(output)) - - # Compare with reference output - assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index af5cebdf2a8b..d3b1374a9dd2 100644 --- a/vllm/lora/models.py +++ 
b/vllm/lora/models.py @@ -200,7 +200,7 @@ def from_local_checkpoint( weights_mapper: Optional[WeightsMapper] = None, tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. - + Args: lora_dir: The local path that has lora data. expected_lora_modules: Name of modules that are expected to be @@ -620,7 +620,7 @@ def _match_target_modules(self, module_name: str): def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ Regarding multimodal models, vLLM currently only supports adding LoRA to - language model. LoRA for other modules, such as the vision tower, will + language model. LoRA for other modules, such as the vision tower, will be filtered out. """ if self.supports_mm: diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py index acbec0cfab9c..dff4d5181efe 100644 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -1,63 +1,99 @@ # SPDX-License-Identifier: Apache-2.0 +import jax +import jax.numpy as jnp import torch +import torch.nn.functional as F +import torch_xla.core.xla_builder as xb +from torch.library import impl +from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard -# Required to register the custom ops -import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import +@jax.jit +def bgmv_jax(inputs, loras, idxs): + return jnp.einsum( + "td,tX,Xld->tl", + inputs, + jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype), + loras, + ) -def bgmv_expand(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True): + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") + + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + jax_import_guard() + return xb.call_jax(bgmv_jax, (inputs, loras, idxs)) + + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) + + +def bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +): """ Args: inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape + + lora_b_weights (torch.Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape + + output_tensor (torch.Tensor): output tensor of shape [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output + add_inputs (bool): Whether or not to add the input tensor to the output tensor. 
""" outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - n_tokens = outputs.size(0) limit = output_tensor.shape[0] if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: limit = 1 - outputs = torch.cat( - (outputs, - torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), - device=outputs.device)), - dim=1) + if output_tensor.shape[1] > outputs.shape[1]: + outputs = F.pad(outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) if add_inputs: - return output_tensor + outputs[:limit, :] + return output_tensor + outputs[:limit, :output_tensor.shape[1]] else: - return outputs[:limit, :] + return outputs[:limit, :output_tensor.shape[1]] -def bgmv_shrink(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0): +def bgmv_shrink( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +): """ Args: inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - lora_b_weights (torch.Tensor): LoRA weights of shape + lora_b_weights (torch.Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. output_tensor (torch.Tensor): (Unused) output tensor (placeholder). - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. scaling (float, optional): Scalar multiplier applied to the output. """ @@ -66,39 +102,41 @@ def bgmv_shrink(inputs: torch.Tensor, lora_indices_tensor) -def bgmv_expand_slice(inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True): +def bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +): """ Args: inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape + + lora_b_weights (torch.Tensor): LoRA weights of shape [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape + + output_tensor (torch.Tensor): output tensor of shape [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output + add_inputs (bool): Whether or not to add the input tensor to the output tensor. 
""" outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - n_tokens = outputs.size(0) - outputs = torch.cat(( - torch.zeros((n_tokens, slice_offset), device=outputs.device), + outputs = F.pad( outputs, - torch.zeros( - (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), - device=outputs.device), - ), - dim=1) + ( + slice_offset, + output_tensor.shape[1] - (slice_offset + slice_size), + 0, + 0, + ), + ) if add_inputs: return output_tensor + outputs diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py deleted file mode 100644 index 35dc307539bf..000000000000 --- a/vllm/lora/ops/xla_ops/pallas.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -import functools - -import jax -import jax.numpy as jnp -import torch -from jax.experimental import pallas as pl -from jax.experimental.pallas import tpu as pltpu -from torch.library import impl -from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, - make_kernel_from_pallas) - -# TODO: Tune these -TOKENS_BLOCK = 16 -LORA_RANK_BLOCK = 128 -DIM_BLOCK_SIZE = 128 - - -def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, - acc_ref, mask_ref): - - @pl.when(pl.program_id(2) == 0) - def _(): - acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) - - t = pl.program_id(0) - - for i in range(bT): - idx = idx_ref[i + bT * t] - mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) - mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) - - acc_ref[...] += jax.lax.dot_general( - inp_ref[...], - lora_ref[idx, ...], (((1, ), (1, )), ((), ())), - preferred_element_type=jnp.float32) * mask_ref[...] - - @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) - def _(): - out_ref[...] = acc_ref[...].astype(out_ref.dtype) - - -@jax.jit -def _bgmv( - idxs: jax.Array, # (T, ) int32 - inputs: jax.Array, # (T, D) model dtype - loras: jax.Array # (N, L, D) model dtype -) -> jax.Array: # (T, L) model dtype - T, D = inputs.shape - N, L, _ = loras.shape - - return pl.pallas_call( - kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK), - out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), - grid_spec=pltpu.PrefetchScalarGridSpec( - num_scalar_prefetch=1, - grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK, - D // DIM_BLOCK_SIZE), - in_specs=[ - pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), - lambda i, j, k, block_idx: (i, k)), - pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), - lambda i, j, k, block_idx: (0, j, k)), - ], - out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK), - lambda i, j, k, block_idx: (i, j)), - scratch_shapes=[ - pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32), - pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32) - ]), - compiler_params=pltpu.TPUCompilerParams( - dimension_semantics=("parallel", "parallel", "arbitrary")), - name="bgmv")(idxs, inputs, loras) - - -def bgmv_shape_function(idxs, inputs, loras): - T, _ = inputs.shape - _, L, _ = loras.shape - - return [((T, L), inputs.dtype)] - - -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) - - -@impl(XLA_LIB, "bgmv", "XLA") -def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - inputs = inputs.to(dtype=loras.dtype) - - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - - jax_import_guard() - kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) - - T, _ = inputs.shape - _, L, D = loras.shape - - # Pad the loras' rank if it's too low. 
This is to allow it to fit in a TPU - # register. This has to happen in pytorch, doing it in Jax will lead to NaNs - L1 = L - if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0: - L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK - - D1 = D - if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: - D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE - - T1 = T - if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: - T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK - - if D1 != D or L1 != L: - loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) - if D1 != D or T1 != T: - inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) - if T1 != T: - idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) - - return kernel(idxs, inputs, loras)[:T, :L] - - -@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, - idxs: torch.IntTensor): - T, _ = inputs.shape - - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - - _, L, _ = loras.shape - - return torch.empty((T, L), device=inputs.device) diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index f3153c6dab03..0556e583f409 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,11 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Union +import math +from typing import TYPE_CHECKING, Optional, Union import torch import torch.nn.functional as F +import torch_xla.core.xla_model as xm from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink +from vllm.lora.punica_wrapper.utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext from .punica_base import PunicaWrapperBase @@ -31,6 +39,15 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._sampler_indices_padded = self._sampler_indices_padded.to( dtype=torch.int32) + torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, + True) + torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, + True) + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) torch._dynamo.mark_dynamic(self._embeddings_indices, 1) torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) @@ -55,15 +72,11 @@ def sampler_indices_padded(self) -> torch.Tensor: def shrink( self, - y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float, ): - if self.no_lora: - return y - return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), - scale) + return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool): @@ -72,7 +85,7 @@ def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, def expand_slice(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - y_total_size: int, add_inputs: bool) -> torch.Tensor: + add_inputs: bool) -> torch.Tensor: return bgmv_expand_slice(x, w_t_all, y, self._get_token_lora_indices(x), y_offset, y_slice_size, add_inputs) @@ -98,9 +111,8 @@ def add_shrink(self, y: Union[tuple[torch.Tensor, 
...], torch.Tensor], x = x.view(-1, x.shape[-1]) for slice_idx in range(len(lora_a_stacked)): - y_s = y[slice_idx] lora_s = lora_a_stacked[slice_idx] - y_s = self.shrink(y_s, x, lora_s, scale) + y_s = self.shrink(x, lora_s, scale) y[slice_idx, :, :] = y_s # type: ignore[index] return y @@ -140,15 +152,12 @@ def add_expand(self, y = self._apply_bias(self._get_token_lora_indices(y), y, output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - y_total_size=sum(output_slices), - add_inputs=add_inputs, - ) + y = self.expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs) offset_left += output_slices[slice_idx] return y.view_as(y_org) @@ -216,12 +225,10 @@ def add_lora_linear(self, if buffer is None: r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default, consistent with the - # triton op T = x.size(0) buffer = torch.zeros( (len(output_slices), T, r), - dtype=torch.float32, + dtype=x.dtype, device=x.device, ) buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) @@ -257,26 +264,16 @@ def add_lora_logits(self, scale (float): Scaling factor. buffer (Optional[torch.Tensor]):Default to None. """ - if self.no_lora: - return y - y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = lora_b_stacked.size(-1) - if buffer is None: - # We set the buffer to be float32 by default, consistent with the - # triton op - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - - buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, - scale) + + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) y = bgmv_expand(buffer, lora_b_stacked, y, - self.sampler_indices, + sampler_indices, add_inputs=True) return y.view_as(y_org) @@ -316,10 +313,92 @@ def _apply_bias( return output.view_as(org_output) + # This performs the same tensor ops as the base method, except it does them + # on the CPU then transfers the results to the TPU + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + # Make sure we don't accidentally collect outside operations + xm.mark_step() + + # Pad the prompt mapping to avoid running into recompiles on the TPU + # TODO: Should this happen inside mapping internally? If so how can we + # avoid having backend specific LoRAMapping classes? 
+ mapping.prompt_mapping = self._pad_prompt_mapping( + mapping.prompt_mapping) + + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + "cpu", + long_lora_context, + ) + self._token_lora_indices = self._pad_to_shape( + base_indices, self._token_lora_indices.shape, + dims=1).to(self.device) + self._sampler_indices = self._pad_to_shape(sampler_indices, + self._sampler_indices.shape, + dims=1).to(self.device) + self._sampler_indices_padded = self._pad_to_shape( + sampler_indices_padded, self._sampler_indices_padded.shape, + dims=1).to(self.device) + self._embeddings_indices = self._pad_to_shape( + embeddings_indices, self._embeddings_indices.shape, + dims=2).to(self.device) + if long_lora_offsets_tensor is not None: + self._long_lora_indices = self._pad_to_shape( + long_lora_offsets_tensor, + self._long_lora_indices.shape, + dims=1).to(self.device) + else: + zeroed = torch.zeros_like(self._long_lora_indices.cpu(), + dtype=torch.int32) + self._long_lora_indices = zeroed.to(self.device) + self.indices_len[:] = indices_len + def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: self.batch_size = 1 - self._lora_indices_per_batch[:self.batch_size].copy_( - token_lora_tensor[:self.batch_size]) - # TODO: .item() is extremely inefficient on TPU, so find a way around it - self.no_lora = torch.all(token_lora_tensor == -1).item() + self._lora_indices_per_batch[:self. + batch_size] = token_lora_tensor[:self. + batch_size] + + def _pad_prompt_mapping( + self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: + num_reqs = len(prompt_mapping) + + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # import + MIN_NUM_SEQS = 8 + + padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + pad_len = padded_num_reqs - num_reqs + + padding = [-1] * pad_len + return tuple(list(prompt_mapping) + padding) + + def _pad_to_shape(self, src, target_shape, dims=1): + if dims == 1: + pad_len = target_shape[0] - src.shape[0] + return F.pad(src, (0, pad_len), value=0).to(torch.int32) + else: + pad_rows = target_shape[0] - src.shape[0] + pad_cols = target_shape[1] - src.shape[1] + return F.pad(src, (0, pad_cols, 0, pad_rows), + value=0).to(torch.int32) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 3cbab840e969..eb8ed622161d 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -80,8 +80,38 @@ def set_active_loras(self, input_batch: InputBatch, lora_requests) @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): + def maybe_setup_dummy_loras(self, lora_config): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_loras = lora_config.max_loras + + # Make dummy lora requests + lora_requests: set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. 
+ for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + @contextmanager + def maybe_select_dummy_loras(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): if lora_config is None: yield else: @@ -108,21 +138,18 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, for lora_id in range(1, num_loras + 1) } - with self.lora_manager.dummy_lora_cache(): - # Add the dummy LoRAs here so _set_active_loras doesn't try to - # load from disk. - for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) - - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), - lora_requests) + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), lora_requests) - yield + yield - # __exit__ code - self.lora_manager.remove_all_adapters() + @contextmanager + def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + with self.maybe_setup_dummy_loras( + lora_config), self.maybe_select_dummy_loras( + lora_config, num_scheduled_tokens): + yield def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 46bcf64ed0c3..669908cb577b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -20,6 +20,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.layers import BaseLayerWithLoRA from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, @@ -152,6 +153,9 @@ def __init__( self.hidden_size = model_config.get_hidden_size() self.vocab_size = model_config.get_vocab_size() + if self.lora_config is not None: + self.vocab_size += self.lora_config.lora_extra_vocab_size + # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope @@ -591,6 +595,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) + layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() per_layer_attn_metadata = { @@ -916,6 +931,7 @@ def load_model(self) -> None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, self.lora_config, self.device) + replace_set_lora(model) # Sync all pending XLA execution during model initialization and weight # loading. 
@@ -980,7 +996,7 @@ def _dummy_run(self, num_tokens: int) -> None: for layer_name in layer_names } - with self.maybe_dummy_run_with_lora( + with self.maybe_select_dummy_loras( self.lora_config, np.array([num_tokens], dtype=np.int32)), set_forward_context( per_layer_attn_metadata, self.vllm_config, 0): @@ -989,6 +1005,13 @@ def _dummy_run(self, num_tokens: int) -> None: inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, + lora_requests) -> None: + xm.mark_step() # Captures input updates + super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + xm.mark_step() # Captures metadata updates + def _precompile_mm_encoder(self) -> None: # Pre-compile MM encoder for all supported data modalities. hf_config = self.vllm_config.model_config.hf_config @@ -1151,7 +1174,10 @@ def _precompile_sample_from_logits(self) -> None: generate_params_if_all_greedy, )) sampling_metadata.all_greedy = all_greedy - self.sample_from_logits(dummy_logits, sampling_metadata) + with self.maybe_select_dummy_loras( + self.lora_config, np.array([num_reqs], + dtype=np.int32)): + self.sample_from_logits(dummy_logits, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1167,7 +1193,9 @@ def _precompile_gather_logprobs(self) -> None: dtype=self._hidden_states_dtype) dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) - self.gather_logprobs(dummy_logits, dummy_tokens) + with self.maybe_select_dummy_loras( + self.lora_config, np.array([num_reqs], dtype=np.int32)): + self.gather_logprobs(dummy_logits, dummy_tokens) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -1178,13 +1206,14 @@ def capture_model(self) -> None: """ Precompile all the subgraphs with possible input shapes. """ - self._precompile_mm_encoder() - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_compute_logits() - self._precompile_structured_decoding() - self._precompile_sample_from_logits() - self._precompile_gather_logprobs() + with self.maybe_setup_dummy_loras(self.lora_config): + self._precompile_mm_encoder() + self._precompile_backbone() + self._precompile_select_hidden_states() + self._precompile_compute_logits() + self._precompile_structured_decoding() + self._precompile_sample_from_logits() + self._precompile_gather_logprobs() def profile_run( self, @@ -1467,11 +1496,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]: """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size - + If padding_gap == 0 then: increase 2X each time (exponential) else: - first increase the size to twice, + first increase the size to twice, then increase the padding size by padding_gap. """ # assert min_token_size is power of 2 @@ -1508,3 +1537,32 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] + + +def replace_set_lora(model): + + def _tpu_set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, + ): + # TODO: The integer index leads to a recompilation, but converting it + # to a tensor doesn't seem to work anymore. This might be fixed with a + # later release of torch_xla. 
+ self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) + xm.mark_step() + + def _tpu_reset_lora(self, index: int): + self._original_reset_lora(index) + xm.mark_step() + + for _, module in model.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + module._original_set_lora = module.set_lora + module._original_reset_lora = module.reset_lora + module.set_lora = _tpu_set_lora.__get__(module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__( + module, module.__class__) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index fa4eb30ccd9a..0707e17afe7a 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -83,10 +83,6 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 - if vllm_config.lora_config is not None: - raise NotImplementedError( - "The V1 TPU backend doesn't support LoRA serving") - def init_device(self): os.environ["PJRT_DEVICE"] = "TPU" # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D @@ -166,7 +162,8 @@ def determine_available_memory(self) -> int: runner_kv_caches) # `max_num_tokens >= max_num_batched_tokens` due to padding. - self.model_runner.profile_run(self.model_runner.max_num_tokens) + with self.model_runner.maybe_setup_dummy_loras(self.lora_config): + self.model_runner.profile_run(self.model_runner.max_num_tokens) # Synchronize before measuring the memory usage. xm.wait_device_ops()
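
The new bgmv_jax replaces the deleted Pallas kernel with a one-hot einsum. A minimal sketch (toy shapes, plain PyTorch instead of JAX) showing that the "td,tX,Xld->tl" contraction matches the gather-based reference that the removed test_pallas_kernels.py used:

import torch

T, D, L, N = 16, 64, 8, 4
inputs = torch.randn(T, D)
loras = torch.randn(N, L, D)
idxs = torch.randint(0, N, (T,))

# Gather-based reference, as in the removed test's ref_bgmv:
# out[t] = loras[idxs[t]] @ inputs[t]
ref = torch.einsum("tld,td->tl", loras[idxs], inputs)

# One-hot formulation used by bgmv_jax ("td,tX,Xld->tl")
one_hot = torch.nn.functional.one_hot(idxs, N).to(inputs.dtype)
out = torch.einsum("td,tX,Xld->tl", inputs, one_hot, loras)

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-4)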
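
In bgmv_expand, the old torch.cat with an explicit zeros block becomes a single right-pad of the bgmv result up to the width of output_tensor. A small sketch of the equivalent shapes, with toy sizes chosen for illustration:

import torch
import torch.nn.functional as F

outputs = torch.randn(4, 8)         # bgmv result: [num_tokens, out_features]
output_tensor = torch.zeros(4, 12)  # [num_tokens, hidden_size * num_slices]

pad_cols = output_tensor.shape[1] - outputs.shape[1]
padded = F.pad(outputs, (0, pad_cols, 0, 0))  # right-pad the last dim only

assert padded.shape == output_tensor.shape
assert torch.equal(padded[:, :8], outputs)
assert torch.all(padded[:, 8:] == 0)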
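
bgmv_expand_slice applies the same idea with a left offset, so the bgmv result lands in its own slice of the concatenated output row. Sketch with hypothetical sizes:

import torch
import torch.nn.functional as F

outputs = torch.ones(2, 3)  # bgmv result, slice_size == 3
total_cols, slice_offset, slice_size = 10, 4, 3

padded = F.pad(
    outputs,
    (slice_offset, total_cols - (slice_offset + slice_size), 0, 0),
)
assert padded.shape == (2, total_cols)
assert torch.all(padded[:, slice_offset:slice_offset + slice_size] == 1)
assert padded.sum() == outputs.sum()  # everything outside the slice is zero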
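
The request-count padding in _pad_prompt_mapping rounds the number of requests up to the next power of two with a floor of MIN_NUM_SEQS and fills the tail with -1 (no LoRA), which keeps the TPU from recompiling for every batch size. A standalone sketch of the same rule, with the constant copied rather than imported:

import math

MIN_NUM_SEQS = 8  # mirrors the constant borrowed from tpu_model_runner

def pad_prompt_mapping(prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
    num_reqs = len(prompt_mapping)
    padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
    return tuple(list(prompt_mapping) + [-1] * (padded_num_reqs - num_reqs))

assert pad_prompt_mapping((1, 2, 3)) == (1, 2, 3, -1, -1, -1, -1, -1)
assert len(pad_prompt_mapping(tuple(range(9)))) == 16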
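
In _prepare_inputs, the padding tokens added for the TPU graphs are attributed to the last request before set_active_loras, so the per-request counts sum to the padded total the graphs were compiled for. A toy sketch of that bookkeeping (all values hypothetical):

import numpy as np

num_scheduled_tokens_per_req = np.array([5, 7, 3], dtype=np.int32)
total_num_scheduled_tokens = int(num_scheduled_tokens_per_req.sum())  # 15
padded_total_num_scheduled_tokens = 16  # hypothetical padded graph size

padded_per_req = np.copy(num_scheduled_tokens_per_req)
padded_per_req[-1] += (padded_total_num_scheduled_tokens -
                       total_num_scheduled_tokens)
assert padded_per_req.sum() == padded_total_num_scheduled_tokens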
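
replace_set_lora rebinds set_lora/reset_lora on each BaseLayerWithLoRA instance via __get__ so that an xm.mark_step() runs after every LoRA weight update. A minimal, framework-free sketch of that binding pattern (the Layer class and return values are made up for illustration):

class Layer:
    def set_lora(self, index: int) -> str:
        return f"set_lora({index})"

def _tpu_set_lora(self, index: int) -> str:
    result = self._original_set_lora(index)
    # the real patch calls xm.mark_step() here to cut the XLA graph
    return result

layer = Layer()
layer._original_set_lora = layer.set_lora           # keep the bound original
layer.set_lora = _tpu_set_lora.__get__(layer, Layer)  # rebind on this instance
assert layer.set_lora(3) == "set_lora(3)"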