diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 87f74277cf90..cfab6b3ef159 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -41,6 +41,8 @@ docker run --privileged --net host --shm-size=16G -it \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ && echo TEST_9 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \ + && echo TEST_10 \ + && pytest -s -v /workspace/vllm/tests/tpu/test_lora.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dc433f9dad26..b940f7190bb2 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -47,7 +47,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, diff --git a/tests/lora/tpu/__init__.py b/tests/lora/tpu/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/lora/tpu/test_pallas_kernels.py b/tests/lora/tpu/test_pallas_kernels.py new file mode 100644 index 000000000000..0a7f58914ace --- /dev/null +++ b/tests/lora/tpu/test_pallas_kernels.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + +N_TOKENS = [16, 1024, 4096] +HIDDEN_SIZES = [1024, 2048, 4096] + +DTYPES = [torch.float16] +NUM_LORA = [1, 4, 16] +RANKS = [32, 256, 512] + + +def generate_test_data(T, D, L, N, seed, dtype=torch.float32): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: torch.Tensor - shape (T, D) + loras: torch.Tensor - shape (N, 1, L, D) + idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) + + ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T + """ + torch.manual_seed(seed) + + inputs = torch.randn((T, D), device="xla", dtype=dtype) + loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) + idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") + + ref_output = ref_bgmv(inputs, loras, idxs) + return inputs, loras, idxs, ref_output + + +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): + selected_loras = loras[idxs] + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(axis=1) + + batch_size, output_size, input_size = selected_loras.shape + return (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", [0]) +def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): + if op_type == "expand": + D, L = L, D + + inputs, loras, idxs, ref_output = generate_test_data( + T, D, L, N, seed, dtype) + + # Run bgmv + if op_type == "shrink": + output = torch.ops.xla.bgmv_shrink(inputs, loras, idxs) + else: + output = torch.ops.xla.bgmv_expand(inputs, loras.transpose(2, 3), idxs) + + # Make sure we have no NaNs + 
assert not torch.any(torch.isnan(output)) + + # Compare with reference output + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", [0]) +def test_lora_laning_correctness(T, D, L, N, dtype, seed): + inputs, loras_a, idxs, _ = generate_test_data(T, D, L, N, seed, dtype) + _, loras_b, _, _ = generate_test_data(T, L, D, N, seed, dtype) + + r1 = ref_bgmv(inputs, loras_a, idxs) + r2 = ref_bgmv(r1, loras_b, idxs) + + o1 = torch.ops.xla.bgmv_shrink(inputs, loras_a, idxs) + o2 = torch.ops.xla.bgmv_expand(o1, loras_b.transpose(2, 3), idxs) + + # Compare with reference output + assert torch.allclose(o2, r2, rtol=1e-2, atol=1e-2) diff --git a/tests/tpu/test_lora.py b/tests/tpu/test_lora.py new file mode 100644 index 000000000000..20b7169910a4 --- /dev/null +++ b/tests/tpu/test_lora.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +import vllm +from vllm.lora.request import LoRARequest + + +@pytest.fixture(scope="function", autouse=True) +def use_v1_only(monkeypatch: pytest.MonkeyPatch): + """ + Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1 + for all tests in this file + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + yield + + +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lora_e2e(num_loras: int): + """ + This test ensures that we can run with LoRA adapters on the TPU backend. + It verifies multiple capabilities: + 1. We can compile a model with LoRA adapters enabled + 2. We can run LoRA adapters + 3. We receive correct outputs when running with multiple LoRA adapters + 4. We can swap LoRA adapters between host and device + """ + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=num_loras, + max_lora_rank=8) + + prompt = "What is 1+1? \n" + + for _ in range(2): + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + assert int(output.strip()[0]) == i + 1 diff --git a/vllm/config.py b/vllm/config.py index 2662c6a84990..68d40346f9c9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2416,8 +2416,8 @@ class LoRAConfig: max_cpu_loras: Optional[int] = None lora_dtype: Optional[Union[torch.dtype, str]] = None lora_extra_vocab_size: int = 256 - # This is a constant. 
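For context on what the new kernel tests above assert, here is a minimal CPU-only sketch (not part of the change) of the bgmv contract: every token gathers one LoRA matrix by its index and is multiplied by it, and a shrink followed by an expand with the transposed weights reproduces the composition checked by test_lora_laning_correctness. Shapes follow the tests: inputs (T, D), loras (N, 1, L, D), idxs (T,).

import torch

def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
    # Gather each token's LoRA matrix, then do a per-token mat-vec product.
    selected = loras[idxs]               # (T, 1, L, D)
    if selected.dim() == 4:
        selected = selected.squeeze(1)   # (T, L, D)
    return torch.einsum("tld,td->tl", selected, inputs)

T, D, L, N = 4, 16, 8, 2                 # toy sizes, much smaller than the test grid
inputs = torch.randn(T, D)
loras = torch.randn(N, 1, L, D)
idxs = torch.randint(0, N, (T,), dtype=torch.long)

shrunk = ref_bgmv(inputs, loras, idxs)                    # "shrink": (T, D) -> (T, L)
expanded = ref_bgmv(shrunk, loras.transpose(2, 3), idxs)  # "expand": (T, L) -> (T, D)
assert shrunk.shape == (T, L) and expanded.shape == (T, D)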
- lora_vocab_padding_size: ClassVar[int] = 256 + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() long_lora_scaling_factors: Optional[tuple[float]] = None bias_enabled: bool = False @@ -2439,6 +2439,7 @@ def compute_hash(self) -> str: factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) factors.append(self.long_lora_scaling_factors) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 41e1ec94145d..e195f8cf5e8e 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -16,6 +16,7 @@ MergedQKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) +from vllm.platforms import current_platform if TYPE_CHECKING: pass @@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): device=x.device, ) - layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + buffers = tensor_model_parallel_all_gather(buffers) - layer.punica_wrapper.add_expand(output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -292,7 +303,11 @@ def apply(self, device=x.device, ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -304,7 +319,7 @@ def apply(self, # NOTE offset are based on the rank. 
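The rebinding pattern introduced above, where the return value of add_shrink/add_expand is captured and reassigned whenever current_platform.can_update_inplace() is false, can be pictured with a small self-contained sketch. The helper below is hypothetical and only mirrors the control flow, not the vLLM API.

from typing import Optional
import torch

CAN_UPDATE_INPLACE = False   # stand-in for current_platform.can_update_inplace()

def add_scaled(out: torch.Tensor, x: torch.Tensor, scale: float) -> Optional[torch.Tensor]:
    if CAN_UPDATE_INPLACE:
        out.add_(x, alpha=scale)   # CUDA-style path: mutate the buffer in place
        return None
    return out + scale * x         # XLA-style path: return a fresh tensor instead

out = torch.zeros(4)
maybe_new = add_scaled(out, torch.ones(4), 2.0)
if not CAN_UPDATE_INPLACE:
    out = maybe_new                # caller rebinds, as `buffers = shrunk_buffers` does above
assert torch.allclose(out, torch.full((4,), 2.0))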
shard_size = self.lora_b_stacked[0].shape[2] offset_start = self.tp_rank * shard_size - self.punica_wrapper.add_expand( + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( output, buffer, self.lora_b_stacked, @@ -313,6 +328,10 @@ def apply(self, offset_start=offset_start, add_input=True, ) + + if not current_platform.can_update_inplace(): + output = lora_output + output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 7a9d5237ab75..7a2c143bba1c 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torch_xla.core.xla_model as xm from transformers import PretrainedConfig from vllm.adapter_commons.layers import AdapterMapping @@ -261,10 +262,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - self.punica_wrapper.add_lora_embedding(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + return full_output.view_as(full_output_org) @classmethod @@ -410,10 +418,13 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + return output @property @@ -1128,15 +1139,23 @@ def _get_logits( torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - lora_logits[-1] = float("-inf") + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu(): + indices_padded = indices_padded[:logits.size(0)] + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): @@ -1146,10 +1165,13 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - # LogitsProcessorWithLoRA always using bgmv - self.punica_wrapper.add_lora_logits(logits, hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output # Remove paddings in vocab (if any). 
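The platform hook used in the logits path above can be illustrated standalone; get_finite_extremes below is a hypothetical stand-in for current_platform.get_infinity_values(), showing why a TPU-style backend can substitute the dtype's finite extremes for literal infinities without changing softmax behaviour.

import torch

def get_finite_extremes(dtype: torch.dtype, use_finite: bool) -> tuple[float, float]:
    if use_finite:                        # TPU-style choice in the diff
        info = torch.finfo(dtype)
        return info.min, info.max
    return float("-inf"), float("inf")    # default platform behaviour

neg_inf, pos_inf = get_finite_extremes(torch.float32, use_finite=True)
logits = torch.tensor([1.0, float("nan"), 2.0])
masked = logits.nan_to_num(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
probs = torch.softmax(masked, dim=-1)
assert probs[1] < 1e-6                    # the masked entry contributes (almost) nothing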
logits = logits[:, :self.base_layer.vocab_size] diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 8164d919ca8b..8b777cdc086b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -199,7 +199,7 @@ def from_local_checkpoint( weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. - + Args: lora_dir: The local path that has lora data. expected_lora_modules: Name of modules that are expected to be @@ -605,7 +605,7 @@ def _match_target_modules(self, module_name: str): def _filter_unsupported_mm_module(self, module_name: str) -> bool: """ Regarding multimodal models, vLLM currently only supports adding LoRA to - language model. LoRA for other modules, such as the vision tower, will + language model. LoRA for other modules, such as the vision tower, will be filtered out. """ if self.supports_mm: diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 000000000000..2b6337d8fd8f --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) +from vllm.lora.ops.xla_ops.pallas import LORA_RANK_BLOCK_SIZE + +__all__ = [ + "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "LORA_RANK_BLOCK_SIZE" +] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 000000000000..1083eb3ea762 --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + + outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), + lora_indices_tensor) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if output_tensor.shape[1] > outputs.shape[1]: + outputs = F.pad(outputs, + (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) + + if add_inputs: + return output_tensor + outputs[:limit, :output_tensor.shape[1]] + else: + return outputs[:limit, :output_tensor.shape[1]] + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + output_tensor (torch.Tensor): (Unused) output tensor (placeholder). + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + scaling (float, optional): Scalar multiplier applied to the output. 
+ """ + + return scaling * torch.ops.xla.bgmv_shrink(inputs, lora_b_weights, + lora_indices_tensor) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + outputs = torch.ops.xla.bgmv_expand(inputs, lora_b_weights.transpose(2, 3), + lora_indices_tensor) + + outputs = F.pad(outputs, (slice_offset, output_tensor.shape[1] - + (slice_offset + slice_size), 0, 0)) + + if add_inputs: + return output_tensor + outputs + else: + return outputs diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py new file mode 100644 index 000000000000..f73cd29a26b7 --- /dev/null +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -0,0 +1,517 @@ +# SPDX-License-Identifier: Apache-2.0 +import functools +import math +from typing import List + +import jax +import jax.numpy as jnp +import torch +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from torch.library import impl +from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, + make_kernel_from_pallas) + +# Ignore "Function definition does not bind loop variable" errors in Pallas +#ruff: noqa: B023 + +XLA_LIB.define( + "bgmv_shrink(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") + +# bgmv_expand needs a flag to enable LoRA laning since it expects its inputs to +# be the outputs of a LoRA laned bgmv_shrink. This is not always the case when +# we use bgmv_expand +XLA_LIB.define( + "bgmv_expand(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") +""" +LoRA Laning Optimization for TPU Matrix Multiplication + +When we run with the TPU we need to keep its MXU (matrix multiplication unit) +well fed to achieve maximum utilisation. +The MXU can perform an (8x128) by (128x128) matmul once every 8 cycles. + +LoRA computations typically take a series of T (1xD) vectors and matmul them +with a (DxL) matrix (shrinking) followed by another matmul with a (LxD) matrix +(expanding). Grouping the vectors we get a (TxD) matrix, so our computations +become matmul((TxD), (DxL)) and matmul((TxL), (LxD)). + +The number of tokens (T) and the hidden dimension (D) are usually greater than +8 and 128 respectively, however the LoRA rank (L) is usually a smaller value, +around 8-64, which means we need to pad L to allow it to fit in a TPU register. + + +------------------+ + | Shrink Operation | + +------------------+ + + L + +------------------+ + D | 1111000000000000 | L + +------------------+ | 1111000000000000 | +------------------+ + | 1111111111111111 | | 1111000000000000 | | 1111000000000000 | + T | 2222222222222222 | x D | 1111000000000000 | = T | 1111000000000000 | + +------------------+ | 1111000000000000 | +------------------+ + | 1111000000000000 | + | 1111000000000000 | + | 1111000000000000 | + +------------------+ + +Here we have 4 tokens each needing a different LoRA adapter, and 1 LoRA adapter +loaded into the MXU. 
After the matmul we end up with the result of applying +LoRA 1 to all T tokens, but since only one token needs LoRA 1, we mask out +everything we don't need to get: + + D + +------------------+ + | 1111000000000000 | + | 0000000000000000 | + +------------------+ + +However, we need: + + L + +------------------+ + | 1111000000000000 | + | 2222000000000000 | + +------------------+ + +So we'll have to perform another matmul. +Overall this shrink wastes time and memory padding the LoRA adapters and running +extra matmuls. + +We can get both reduce the number of matmuls used and the amount of applied +padding by grouping the LoRA adapters into multiple "lanes". + + L + +------------------+ + D | 1111222200000000 | L + +------------------+ | 1111222200000000 | +------------------+ + | 1111111111111111 | | 1111222200000000 | | 1111222200000000 | + T | 2222222222222222 | x D | 1111222200000000 | = T | 1111222200000000 | + +------------------+ | 1111222200000000 | +------------------+ + | 1111222200000000 | + | 1111222200000000 | + | 1111222200000000 | + +------------------+ + + +Now we're able to compute the outputs of 4 different LoRA adapters in the same +8 cycles. However we don't need all these results so we'll again mask out +everything we don't need to get: + + L + +------------------+ + | 1111000000000000 | + | 0000222200000000 | + +------------------+ + +But now our outputs aren't aligned properly, so we would need to apply an extra +shuffle operation. + + +------------------+ + | Expand Operation | + +------------------+ + +When expanding we end up wasting space in both matrix registers. + + D + +------------------+ + L | 1111111111111111 | D + +------------------+ | 1111111111111111 | +------------------+ + | 1111000000000000 | | 1111111111111111 | | 1111111111111111 | + T | 2222000000000000 | x L | 1111111111111111 | = T | 1111111111111111 | + +------------------+ | 0000000000000000 | +------------------+ + | 0000000000000000 | + | 0000000000000000 | + | 0000000000000000 | + +------------------+ + +But, if we use LoRA Laning like before, we can waste less space. We would also +have to shuffle the input so it applies to the right adapter. + + D + +------------------+ + L | 1111111111111111 | D + +------------------+ | 1111111111111111 | +------------------+ + | 1111000000000000 | | 1111111111111111 | | 1111111111111111 | + T | 0000222200000000 | x L | 1111111111111111 | = T | 2222222222222222 | + +------------------+ | 2222222222222222 | +------------------+ + | 2222222222222222 | + | 2222222222222222 | + | 2222222222222222 | + +------------------+ + +Since this shuffling is the exact opposite of the operation we do at the end of +the Shrink operation, we can skip both shuffles. + +""" + +LORA_RANK_BLOCK_SIZE = 256 + + +def _bgmv_shrink_kernel(bT: int, bL: int, n_lora_lanes: int, lane_size: int, + max_num_loras: int, idx_ref, inp_ref, lora_ref, + out_ref, acc_ref, mask_ref, lanes_ref): + + t_idx = pl.program_id(0) + l_idx = pl.program_id(1) + d_idx = pl.program_id(2) + + @pl.when((t_idx == 0) & (l_idx == 0) & (d_idx == 0)) + def _(): + lanes_ref[...] = jnp.zeros_like(lanes_ref[...], dtype=jnp.float32) + ones = jnp.ones((lane_size, ), dtype=jnp.float32) + + for i in range(n_lora_lanes): + start = i * lane_size + end = start + lane_size + lanes_ref.at[i, start:end].set(ones) + + @pl.when(d_idx == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + if max_num_loras == 1 and n_lora_lanes == 1: + acc_ref[...] 
+= jax.lax.dot_general(inp_ref[...], + lora_ref[0, ...], + (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) + else: + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + + def _mask_setup_step(i, valid): + idx = idx_ref[i + bT * t_idx] + inner_lane_idx = idx % n_lora_lanes + outer_lane_idx = idx // n_lora_lanes + + mask_ref.at[outer_lane_idx, i, :].set(lanes_ref[inner_lane_idx]) + + return valid | (1 << outer_lane_idx) + + valid = jax.lax.fori_loop(0, bT, _mask_setup_step, 0) + + def _lora_matmul_step(lane_idx, check_bit): + + @pl.when((valid & check_bit) > 0) + def _(): + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[lane_idx, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[lane_idx, + ...] + + return check_bit << 1 + + _ = jax.lax.fori_loop(0, max_num_loras, _lora_matmul_step, 1) + + @pl.when(d_idx == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@functools.partial(jax.jit, + static_argnames=[ + "TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK", + "N_LORA_LANES", "LANE_SIZE" + ]) +def _bgmv_shrink( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array, # (N, L, D) model dtype + *, + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int, + N_LORA_LANES: int, + LANE_SIZE: int) -> jax.Array: # (T, L) model dtype + T, D = inputs.shape + N, L, _ = loras.shape + + return pl.pallas_call( + kernel=functools.partial(_bgmv_shrink_kernel, TOKEN_BLOCK, LORA_BLOCK, + N_LORA_LANES, LANE_SIZE, N), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), + in_specs=[ + pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, LORA_BLOCK, DIM_BLOCK), + lambda i, j, k, block_idx: (0, j, k)), + ], + out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((N, TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((N_LORA_LANES, LORA_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv")(idxs, inputs, loras) + + +def bgmv_shrink_shape_function(idxs, inputs, loras): + T, _ = inputs.shape + _, L, _ = loras.shape + + return [((T, L), inputs.dtype)] + + +@impl(XLA_LIB, "bgmv_shrink", "XLA") +def bgmv_shrink_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + T, _ = inputs.shape + N, L, D = loras.shape + + TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) + LORA_BLOCK = LORA_RANK_BLOCK_SIZE + DIM_BLOCK = largest_divisor(D, [256, 512, 1024]) + + # See if we can fit multiple LoRAs in a register. This would activate LoRA + # laning + N_LORA_LANES = math.ceil(LORA_BLOCK / L) + LANE_SIZE = min(L, LORA_BLOCK) + if N_LORA_LANES > 1 and N > 1: + pad_N = next_multiple_of(N, N_LORA_LANES) - N + new_N = N + pad_N + + loras = torch.nn.functional.pad(loras, (0, 0, 0, 0, 0, pad_N)) + loras = loras.reshape((new_N // N_LORA_LANES, LORA_BLOCK, D)) + N, L, D = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs + pad_L = 0 + if LORA_BLOCK > L or L % LORA_BLOCK != 0: + pad_L = next_multiple_of(L, LORA_BLOCK) - L + pad_D = 0 + if DIM_BLOCK > D or D % DIM_BLOCK != 0: + pad_D = next_multiple_of(D, DIM_BLOCK) - D + + pad_T = 0 + if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: + pad_T = next_multiple_of(T, TOKEN_BLOCK) - T + + if pad_D != 0 or pad_L != 0: + loras = torch.nn.functional.pad(loras, (0, pad_D, 0, pad_L, 0, 0)) + if pad_D != 0 or pad_T != 0: + inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T)) + if pad_T != T: + idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) + + jax_import_guard() + kernel = make_kernel_from_pallas( + functools.partial(_bgmv_shrink, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK, + N_LORA_LANES=N_LORA_LANES, + LANE_SIZE=LANE_SIZE), bgmv_shrink_shape_function) + + return kernel(idxs, inputs, loras)[:T, :L] + + +@impl(XLA_LIB, "bgmv_shrink", "CompositeExplicitAutograd") +def bgmv_shrink_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + N, L, _ = loras.shape + + LORA_BLOCK = LORA_RANK_BLOCK_SIZE + N_LORA_LANES = math.ceil(LORA_BLOCK / L) + if N_LORA_LANES > 1 and N > 1: + L = LORA_BLOCK + + return torch.empty((T, L), device=inputs.device) + + +# This kernel is similar to the one above but it assumes that the LoRA adapters +# have been pre-transposed. This lets us skip the data copies involved in +# transposing. +# We only need this for the expand op since the LoRA dimensions in the shrink op +# are small enough that the TPU can gather them without a data copy. +def _bgmv_expand_kernel(bT: int, bL: int, max_num_loras: int, idx_ref, inp_ref, + lora_ref, out_ref, acc_ref, mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + if max_num_loras == 1: + acc_ref[...] += jax.lax.dot(inp_ref[...], + lora_ref[0, ...], + preferred_element_type=jnp.float32) + else: + t = pl.program_id(0) + + ones = jnp.ones((bL, ), dtype=jnp.float32) + + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + + def _mask_setup_step(i, valid): + idx = idx_ref[i + bT * t] + lane_idx = idx % max_num_loras + + mask_ref.at[lane_idx, i, :].set(ones) + return valid | (1 << lane_idx) + + valid = jax.lax.fori_loop(0, bT, _mask_setup_step, 0) + + def _lora_matmul_step(lane_idx, check_bit): + + @pl.when((valid & check_bit) > 0) + def _(): + acc_ref[...] += jax.lax.dot( + inp_ref[...], + lora_ref[lane_idx, ...], + preferred_element_type=jnp.float32) * mask_ref[lane_idx, + ...] + + return check_bit << 1 + + _ = jax.lax.fori_loop(0, max_num_loras, _lora_matmul_step, 1) + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] 
= acc_ref[...].astype(out_ref.dtype) + + +@functools.partial(jax.jit, + static_argnames=["TOKEN_BLOCK", "LORA_BLOCK", "DIM_BLOCK"]) +def _bgmv_expand( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array, # (N, L, D) model dtype + *, + TOKEN_BLOCK: int, + LORA_BLOCK: int, + DIM_BLOCK: int) -> jax.Array: # (T, L) model dtype + T, D = inputs.shape + N, _, L = loras.shape + + return pl.pallas_call( + kernel=functools.partial(_bgmv_expand_kernel, TOKEN_BLOCK, LORA_BLOCK, + N), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKEN_BLOCK, L // LORA_BLOCK, D // DIM_BLOCK), + in_specs=[ + pl.BlockSpec((TOKEN_BLOCK, DIM_BLOCK), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, DIM_BLOCK, LORA_BLOCK), + lambda i, j, k, block_idx: (0, k, j)), + ], + out_specs=pl.BlockSpec((TOKEN_BLOCK, LORA_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKEN_BLOCK, LORA_BLOCK), jnp.float32), + pltpu.VMEM((N, TOKEN_BLOCK, LORA_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv_expand")(idxs, inputs, loras) + + +def bgmv_expand_shape_function(idxs, inputs, loras): + T, _ = inputs.shape + _, _, L = loras.shape + + return [((T, L), inputs.dtype)] + + +@impl(XLA_LIB, "bgmv_expand", "XLA") +def bgmv_expand_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + T, DI = inputs.shape + N, D, L = loras.shape + + TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128) + LORA_BLOCK = largest_divisor(L, [256, 512, 1024]) + DIM_BLOCK = LORA_RANK_BLOCK_SIZE + + # See if we can fit multiple LoRAs in a register. This would activate LoRA + # laning + N_LORA_LANES = math.ceil(DIM_BLOCK / D) + if D != DI and N_LORA_LANES > 1 and N > 1: + pad_N = next_multiple_of(N, N_LORA_LANES) - N + new_N = N + pad_N + + loras = torch.nn.functional.pad(loras, (0, 0, 0, 0, 0, pad_N)) + loras = loras.reshape((new_N // N_LORA_LANES, DIM_BLOCK, L)) + idxs = idxs // N_LORA_LANES + N, D, L = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. 
This has to happen in pytorch, doing it in Jax will lead to NaNs + pad_L = 0 + if LORA_BLOCK > L or L % LORA_BLOCK != 0: + pad_L = next_multiple_of(L, LORA_BLOCK) - L + + pad_D = 0 + if DIM_BLOCK > D or D % DIM_BLOCK != 0: + pad_D = next_multiple_of(D, DIM_BLOCK) - D + + pad_T = 0 + if TOKEN_BLOCK > T or T % TOKEN_BLOCK != 0: + pad_T = next_multiple_of(T, TOKEN_BLOCK) - T + + if pad_D != 0 or pad_L != 0: + loras = torch.nn.functional.pad(loras, (0, pad_L, 0, pad_D, 0, 0)) + if pad_D != 0 or pad_T != 0: + inputs = torch.nn.functional.pad(inputs, (0, pad_D, 0, pad_T)) + if pad_T != T: + idxs = torch.nn.functional.pad(idxs, ((0, pad_T))) + + jax_import_guard() + + kernel = make_kernel_from_pallas( + functools.partial(_bgmv_expand, + TOKEN_BLOCK=TOKEN_BLOCK, + LORA_BLOCK=LORA_BLOCK, + DIM_BLOCK=DIM_BLOCK), bgmv_expand_shape_function) + + return kernel(idxs, inputs, loras)[:T, :L] + + +@impl(XLA_LIB, "bgmv_expand", "CompositeExplicitAutograd") +def bgmv_expand_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, _, L = loras.shape + + return torch.empty((T, L), device=inputs.device) + + +def largest_divisor(n: int, divs: List[int]) -> int: + for div in sorted(divs, reverse=True): + if n % div == 0: + return div + return max(divs) + + +def next_multiple_of(n: int, mult: int) -> int: + return math.ceil(n / mult) * mult + + +def get_bounded_value(_min: int, val: int, _max: int) -> int: + return min(max(_min, val), _max) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 94fa3f27ab60..570cd1b756a9 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -48,7 +48,7 @@ def add_shrink( lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. """ @@ -66,7 +66,7 @@ def add_expand( offset_start: int = 0, add_inputs=True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. """ @@ -80,7 +80,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. @@ -98,7 +98,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. """ @@ -114,7 +114,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ @@ -342,7 +342,7 @@ def update_metadata( @abstractmethod def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs) -> None: + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -369,7 +369,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. 
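The padding helpers added at the bottom of pallas.py above decide how much each operand is padded before the kernel launch; the walk-through below uses assumed shapes rather than anything taken from a real model.

import math
from typing import List

def next_multiple_of(n: int, mult: int) -> int:
    return math.ceil(n / mult) * mult

def get_bounded_value(_min: int, val: int, _max: int) -> int:
    return min(max(_min, val), _max)

def largest_divisor(n: int, divs: List[int]) -> int:
    for div in sorted(divs, reverse=True):
        if n % div == 0:
            return div
    return max(divs)

T, D = 100, 3072                                                    # assumed token count and hidden size
TOKEN_BLOCK = get_bounded_value(16, next_multiple_of(T, 16), 128)   # 112
DIM_BLOCK = largest_divisor(D, [256, 512, 1024])                    # 1024
pad_T = next_multiple_of(T, TOKEN_BLOCK) - T                        # 12 padded tokens
pad_D = next_multiple_of(D, DIM_BLOCK) - D                          # 0, already aligned
assert (T + pad_T) % TOKEN_BLOCK == 0 and (D + pad_D) % DIM_BLOCK == 0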
@@ -401,7 +401,7 @@ def add_lora_embedding(self, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -428,7 +428,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -463,7 +463,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 000000000000..6cc98f683746 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,405 @@ +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch_xla.core.xla_model as xm + +from vllm.lora.ops.xla_ops import (LORA_RANK_BLOCK_SIZE, bgmv_expand, + bgmv_expand_slice, bgmv_shrink) +from vllm.lora.punica_wrapper.utils import convert_mapping + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext + +from .punica_base import PunicaWrapperBase + + +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which + # isn't supported by the TPU. So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. + self._token_lora_indices = self._token_lora_indices.to( + dtype=torch.int32) + self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to( + dtype=torch.int32) + + torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, + True) + torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True) + torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, + True) + + def mark_compiled(self): + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + + def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: + return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. 
+ """ + return self._embeddings_indices[:] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + + def shrink( + self, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) + + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) + + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + torch.ops.xla.dynamo_set_buffer_donor_(y, True) + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): + lora_s = lora_a_stacked[slice_idx] + y_s = self.shrink(x, lora_s, scale) + y[slice_idx, :, :] = y_s # type: ignore[index] + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> torch.Tensor: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + if lora_bias_stacked is not None: + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + y = self.expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs) + offset_left += output_slices[slice_idx] + return y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. 
+ """ + + # Embedding layer only needs the expand op + return self.expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> torch.Tensor: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will not be changed in-place. + x (torch.Tensor): Input tensor (T, E) + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + + if buffer is None: + r = max(lora_b_stacked[0].size(-1), LORA_RANK_BLOCK_SIZE) + T = x.size(0) + buffer = torch.zeros( + (len(output_slices), T, r), + dtype=x.dtype, + device=x.device, + ) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. 
+ """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + + sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) + buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + sampler_indices, + add_inputs=True) + return y.view_as(y_org) + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias = torch.where(indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view_as(org_output) + + # This performs the same tensor ops as the base method, except it does them + # on the CPU then transfers the results to the TPU + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + # Make sure we don't accidentally collect outside operations + xm.mark_step() + + # Pad the prompt mapping to avoid running into recompiles on the TPU + # TODO: Should this happen inside mapping internally? If so how can we + # avoid having backend specific LoRAMapping classes? + mapping.prompt_mapping = self._pad_prompt_mapping( + mapping.prompt_mapping) + + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + "cpu", + long_lora_context, + ) + self._token_lora_indices = self._pad_to_shape( + base_indices, self._token_lora_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices = self._pad_to_shape( + sampler_indices, self._sampler_indices.shape, dims=1 + ).to(self.device) + self._sampler_indices_padded = self._pad_to_shape( + sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 + ).to(self.device) + self._embeddings_indices = self._pad_to_shape( + embeddings_indices, self._embeddings_indices.shape, dims=2 + ).to(self.device) + if long_lora_offsets_tensor is not None: + self._long_lora_indices = self._pad_to_shape( + long_lora_offsets_tensor, self._long_lora_indices.shape, dims=1 + ).to(self.device) + else: + zeroed = torch.zeros_like( + self._long_lora_indices.cpu(), dtype=torch.int32 + ) + self._long_lora_indices = zeroed.to(self.device) + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self. + batch_size] = token_lora_tensor[:self. 
+ batch_size] + + def _pad_prompt_mapping( + self, prompt_mapping: Tuple[int, ...]) -> Tuple[int, ...]: + num_reqs = len(prompt_mapping) + + # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular + # import + MIN_NUM_SEQS = 8 + + padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) + pad_len = padded_num_reqs - num_reqs + + padding = [-1] * pad_len + return tuple(list(prompt_mapping) + padding) + + def _pad_to_shape(self, src, target_shape, dims=1): + if dims == 1: + pad_len = target_shape[0] - src.shape[0] + return F.pad(src, (0, pad_len), value=0).to(torch.int32) + else: + pad_rows = target_shape[0] - src.shape[0] + pad_cols = target_shape[1] - src.shape[1] + return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) + diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index dbc2d27c597f..bb9c606cda98 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,13 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 + embeddings_indices = torch.where(embeddings_indices == -1, + embeddings_indices, max_loras - 1) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.where(sampler_indices_padded == -1, + sampler_indices_padded, max_loras - 1) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 31a7ffbd910d..d25784e4a085 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -328,6 +328,27 @@ def get_punica_wrapper(cls) -> str: """ raise NotImplementedError + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + """ + Return the platform specific values for (-inf, inf) + """ + return float("-inf"), float("inf") + + @classmethod + def can_update_inplace(cls) -> bool: + """ + Checks if the platform allows inplace memory updates + """ + return True + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + """ + Returns how much padding the LoRA logits need for kernels + """ + return 256 + @classmethod def get_device_communicator_cls(cls) -> str: """ diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 61e84a6d6f95..166cbbdb57fb 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union import torch @@ -66,6 +66,22 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + + @classmethod + def can_update_inplace(cls): + return False + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 1 + @classmethod def inference_mode(cls): return torch.no_grad() diff --git 
a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index a8a19e0e6206..35af83e9cc80 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -84,8 +84,38 @@ def set_active_loras(self, input_batch: InputBatch, lora_requests) @contextmanager - def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, - num_scheduled_tokens: np.ndarray): + def maybe_setup_dummy_loras(self, lora_config): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_loras = lora_config.max_loras + + # Make dummy lora requests + lora_requests: set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + @contextmanager + def maybe_select_dummy_loras(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): if lora_config is None: yield else: @@ -112,21 +142,18 @@ def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, for lora_id in range(1, num_loras + 1) } - with self.lora_manager.dummy_lora_cache(): - # Add the dummy LoRAs here so _set_active_loras doesn't try to - # load from disk. - for lr in lora_requests: - self.lora_manager.add_dummy_lora( - lr, rank=self.LORA_WARMUP_RANK) - - self._set_active_loras(tuple(prompt_lora_mapping), - tuple(token_lora_mapping), - lora_requests) + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), lora_requests) - yield + yield - # __exit__ code - self.lora_manager.remove_all_adapters() + @contextmanager + def maybe_dummy_run_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + with self.maybe_setup_dummy_loras( + lora_config), self.maybe_select_dummy_loras( + lora_config, num_scheduled_tokens): + yield def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 773c426474fc..d427fa32ac30 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -19,6 +19,8 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.ops.xla_ops import LORA_RANK_BLOCK_SIZE +from vllm.lora.request import LoRARequest from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -37,9 +39,11 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm.lora.layers import BaseLayerWithLoRA if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -54,7 +58,7 @@ MIN_NUM_SEQS = 8 -class TPUModelRunner: +class TPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -184,6 +188,18 @@ def __init__( self.num_reqs_paddings = 
_get_req_paddings( min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + if self.lora_config is not None: + # This makes us pad at initialisation time so we can avoid padding + # at runtime, which introduces long stalls + self.lora_config.max_lora_rank = _get_padded_lora_rank( + self.lora_config.max_lora_rank, self.lora_config.max_loras) + + if self.lora_config is not None: + # This makes us pad at initialisation time so we can avoid padding + # at runtime, which introduces long stalls + self.lora_config.max_lora_rank = _get_padded_lora_rank( + self.lora_config.max_lora_rank, self.lora_config.max_loras) + def _update_num_xla_graphs(self, case_str): check_comp = self.check_recompilation and not self.enforce_eager if not check_comp: @@ -510,6 +526,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Padded to avoid recompiling when `num_reqs` varies. logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) + + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) return attn_metadata, logits_indices, padded_num_reqs def _scatter_placeholders( @@ -797,6 +824,15 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) + if self.lora_config is not None: + model = self.load_lora_model(model, self.model_config, + self.scheduler_config, + self.lora_config, self.device) + replace_set_lora(model) + punica_wrapper = self.lora_manager._adapter_manager.punica_wrapper + if not self.enforce_eager: + punica_wrapper.mark_compiled() + # Sync all pending XLA execution during model initialization and weight # loading. 
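A minimal sketch of the token-padding bookkeeping performed in _prepare_inputs above: the difference between the padded and the real token count is attributed to the final request so the per-request LoRA mapping covers every padded token. The array values are made up for illustration.

import numpy as np

num_scheduled_tokens_per_req = np.array([5, 3, 4], dtype=np.int32)   # assumed schedule
total = int(num_scheduled_tokens_per_req.sum())                      # 12 real tokens
padded_total = 16                                                     # assumed padding bucket

padded = np.copy(num_scheduled_tokens_per_req)    # copy to avoid mutating the original state
padded[-1] += padded_total - total                # padding tokens ride on the last request
assert int(padded.sum()) == padded_total          # now matches the padded batch size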
xm.mark_step() @@ -846,6 +882,8 @@ def _dummy_run(self, num_tokens: int) -> None: num_seqs=num_seqs, ) + xm.mark_step() # Capture tensors created when setting up + if self.is_multimodal_model: torch._dynamo.mark_dynamic(inputs_embeds, 0) else: @@ -853,12 +891,22 @@ def _dummy_run(self, num_tokens: int) -> None: torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - with set_forward_context(attn_metadata, self.vllm_config, 0): + with self.maybe_select_dummy_loras( + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + attn_metadata, self.vllm_config, 0): out = self.model(input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds) self._hidden_states_dtype = out.dtype + def _set_active_loras(self, prompt_lora_mapping, token_lora_mapping, + lora_requests) -> None: + xm.mark_step() # Captures input updates + super()._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + xm.mark_step() # Captures metadata updates + def _precompile_backbone(self) -> None: logger.info("Compiling the model with different input shapes.") @@ -915,7 +963,10 @@ def _precompile_sample_from_hidden(self) -> None: generate_params_if_all_greedy, )) sampling_metadata.all_greedy = all_greedy - self.sample_from_hidden(dummy_hidden, sampling_metadata) + with self.maybe_select_dummy_loras( + self.lora_config, np.array([num_reqs], + dtype=np.int32)): + self.sample_from_hidden(dummy_hidden, sampling_metadata) logger.info(" -- num_seqs: %d", num_reqs) xm.wait_device_ops() end = time.perf_counter() @@ -927,9 +978,10 @@ def capture_model(self) -> None: Precompile all the subgraphs with possible input shapes. """ # TODO: precompile encoder - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_sample_from_hidden() + with self.maybe_setup_dummy_loras(self.lora_config): + self._precompile_backbone() + self._precompile_select_hidden_states() + self._precompile_sample_from_hidden() def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -1009,7 +1061,7 @@ def get_multimodal_embeddings(self, *args, **kwargs): def get_input_embeddings(self, *args, **kwargs): return self.model.get_input_embeddings(*args, **kwargs) - + def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: logger.info("Preparing request paddings:") @@ -1033,11 +1085,11 @@ def _get_token_paddings(min_token_size: int, max_token_size: int, padding_gap: int) -> list[int]: """Generate a list of padding size, starting from min_token_size, ending with a number that can cover max_token_size - + If padding_gap == 0 then: increase 2X each time (exponential) else: - first increase the size to twice, + first increase the size to twice, then increase the padding size by padding_gap. 
""" # assert min_token_size is power of 2 @@ -1073,3 +1125,50 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: index = bisect.bisect_left(paddings, x) assert index < len(paddings) return paddings[index] + + +def _get_padded_lora_rank(max_lora_rank: int, max_num_loras: int) -> int: + max_num_loras += 1 + + # If we have enough LoRAs to use laning without padding + if max_lora_rank * max_num_loras >= LORA_RANK_BLOCK_SIZE: + return max_lora_rank + + return 1 << (LORA_RANK_BLOCK_SIZE // max_num_loras).bit_length() + + +def _create_dummy_scheduled_tokens(total_tokens: int, + num_prompts: int) -> np.ndarray: + assert num_prompts <= total_tokens, "Expected num_prompts < total_tokens" + base_tokens = total_tokens // num_prompts + leftover_tokens = total_tokens % num_prompts + + tokens = np.full((num_prompts, ), base_tokens, dtype=np.int32) + tokens[-1] += leftover_tokens + + return tokens + +def replace_set_lora(model): + def _tpu_set_lora( + self, + idx: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None + ): + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_set_lora(index, lora_a, lora_b, embeddings_tensor, bias) + xm.mark_step() + + def _tpu_reset_lora(self, idx: int): + index = torch.tensor([idx], dtype=torch.int32, device="xla") + self._original_reset_lora(index) + xm.mark_step() + + for _, module in model.named_modules(): + if isinstance(module, BaseLayerWithLoRA): + module._original_set_lora = module.set_lora + module._original_reset_lora = module.reset_lora + module.set_lora = _tpu_set_lora.__get__(module, module.__class__) + module.reset_lora = _tpu_reset_lora.__get__(module, module.__class__) \ No newline at end of file diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 73c43969b87b..99ceb61bad60 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -15,6 +15,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput @@ -156,8 +157,9 @@ def determine_available_memory(self) -> int: self.vllm_config.compilation_config.static_forward_context, runner_kv_caches) - self.model_runner._dummy_run( - self.scheduler_config.max_num_batched_tokens) + with self.model_runner.maybe_setup_dummy_loras(self.lora_config): + self.model_runner._dummy_run( + self.scheduler_config.max_num_batched_tokens) # Synchronize before measuring the memory usage. xm.wait_device_ops() @@ -207,6 +209,9 @@ def profile(self, is_start: bool = True): else: xp.stop_trace() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def load_model(self) -> None: self.model_runner.load_model() diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 71b4b38fb9d6..2e9fe3d6a4d4 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -54,6 +54,11 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + if vllm_config.lora_config is not None: + raise NotImplementedError( + """The V0 TPU backend doesn't support LoRA serving, please try \ + V1 by setting VLLM_USE_V1=1""") + def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" torch.set_grad_enabled(False)