From c8a8c5b180945e571a81a4ab31f0808941075895 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 11 Dec 2024 12:48:15 +0000 Subject: [PATCH 01/23] Added Multi-LoRA implementation for the CPU backend Signed-off-by: Akshat Tripathi --- tests/lora/conftest.py | 32 +++-- tests/lora/test_layers.py | 41 ++++-- tests/lora/test_lora_manager.py | 20 +-- tests/lora/test_mixtral.py | 4 +- tests/lora/test_punica_sizes.py | 19 ++- tests/lora/test_punica_variation.py | 52 ++++--- tests/lora/test_quant_model.py | 3 +- vllm/lora/ops/__init__.py | 18 +++ vllm/lora/ops/default/__init__.py | 0 vllm/lora/ops/default/lora_ops.py | 113 +++++++++++++++ vllm/lora/ops/triton/__init__.py | 0 vllm/lora/ops/{ => triton}/bgmv_expand.py | 0 .../ops/{ => triton}/bgmv_expand_slice.py | 0 vllm/lora/ops/{ => triton}/bgmv_shrink.py | 0 vllm/lora/ops/{ => triton}/sgmv_expand.py | 0 .../ops/{ => triton}/sgmv_expand_slice.py | 0 vllm/lora/ops/{ => triton}/sgmv_shrink.py | 0 vllm/lora/ops/{ => triton}/utils.py | 0 vllm/lora/punica_wrapper/punica_cpu.py | 14 ++ vllm/lora/punica_wrapper/punica_gpu.py | 14 +- vllm/lora/punica_wrapper/punica_selector.py | 5 + vllm/worker/cpu_model_runner.py | 133 ++++++++++++++++-- vllm/worker/cpu_worker.py | 20 ++- 23 files changed, 400 insertions(+), 88 deletions(-) create mode 100644 vllm/lora/ops/default/__init__.py create mode 100644 vllm/lora/ops/default/lora_ops.py create mode 100644 vllm/lora/ops/triton/__init__.py rename vllm/lora/ops/{ => triton}/bgmv_expand.py (100%) rename vllm/lora/ops/{ => triton}/bgmv_expand_slice.py (100%) rename vllm/lora/ops/{ => triton}/bgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_expand.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_expand_slice.py (100%) rename vllm/lora/ops/{ => triton}/sgmv_shrink.py (100%) rename vllm/lora/ops/{ => triton}/utils.py (100%) create mode 100644 vllm/lora/punica_wrapper/punica_cpu.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 29ecf3780820..df04c3e0390c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from vllm.platforms import current_platform class ContextIDInfo(TypedDict): @@ -64,13 +65,16 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - world_size=1, - rank=0, - distributed_init_method=f"file://{temp_file}", - local_rank=0, - backend="nccl", - ) + + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + + init_distributed_environment(world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend=backend) initialize_model_parallel(1, 1) yield cleanup_dist_env_and_memory(shutdown_ray=True) @@ -80,13 +84,15 @@ def dist_init(): def dist_init_torch_only(): if torch.distributed.is_initialized(): return + backend = "nccl" + if current_platform.is_cpu(): + backend = "gloo" + temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group( - backend="nccl", - world_size=1, - rank=0, - init_method=f"file://{temp_file}", - ) + torch.distributed.init_process_group(world_size=1, + rank=0, + init_method=f"file://{temp_file}", + backend=backend) @pytest.fixture diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index fb8c0b2a7ba2..2f5dddf8c946 100644 --- a/tests/lora/test_layers.py 
+++ b/tests/lora/test_layers.py @@ -48,11 +48,19 @@ torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } -# TODO: Modify this based on platform -DEVICES = [ + +pytestmark = pytest.mark.skipif( + not (current_platform.is_cuda_alike() or current_platform.is_cpu()), + reason="Backend not supported") + +CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +CPU_DEVICES = ["cpu"] + +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES + #For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -198,6 +206,10 @@ def check_punica_wrapper(punica_wrapper) -> bool: from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU return type(punica_wrapper) is PunicaWrapperGPU + elif current_platform.is_cpu(): + from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU + + return type(punica_wrapper) is PunicaWrapperCPU else: return False @@ -211,7 +223,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA # device, see: https://github.com/triton-lang/triton/issues/2925 # Same below. - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 @@ -313,7 +326,9 @@ def create_random_embedding_layer(): def test_embeddings_with_new_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) @@ -450,7 +465,9 @@ def create_random_embedding_layer(): def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, stage) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) @@ -582,7 +599,9 @@ def _pretest(): def test_linear_replicated(dist_init, num_loras, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -695,7 +714,9 @@ def create_random_linear_replicated_layer(): def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -818,7 +839,9 @@ def create_random_linear_parallel_layer(): def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, device, stage, bias_enabled) -> None: - torch.cuda.set_device(device) + if current_platform.is_cuda_alike(): + torch.cuda.set_device(device) + torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device) assert check_punica_wrapper(punica_wrapper) @@ -971,6 +994,8 @@ class FakeConfig: @pytest.mark.parametrize("rotary_dim", [None, 32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) 
+@pytest.mark.skipif(not current_platform.is_cuda_alike(), + reason="Only CUDA backends are supported") def test_rotary_embedding_long_context(dist_init, num_loras, device, scaling_factors, max_position, is_neox_style, rotary_dim, head_size, diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 0b76f466702f..927d8392452b 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -19,6 +19,7 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.platforms import current_platform EMBEDDING_MODULES = { "embed_tokens": "input_embeddings", @@ -30,6 +31,9 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +CPU_DEVICES = ["cpu"] + +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES def test_peft_helper(sql_lora_files): @@ -77,7 +81,7 @@ def test_peft_helper(sql_lora_files): PEFTHelper.from_dict(config) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) @@ -165,7 +169,7 @@ def test_replace_submodules(dist_init, dummy_model): manager = LoRAModelManager( model, 1, 1, 1, LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), - torch.device("cuda")) + torch.device(DEVICES[0])) model = manager.model assert isinstance(model.get_submodule("dense1"), @@ -177,7 +181,7 @@ def test_replace_submodules(dist_init, dummy_model): RowParallelLinearWithLoRA) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -238,7 +242,7 @@ def test_lora_model_manager(dist_init, dummy_model, device): assert manager.punica_wrapper.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] @@ -330,7 +334,7 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager @@ -460,7 +464,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): assert manager.device == device -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) @@ -539,7 +543,7 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, sql_lora_files, device): # Should remove every LoRA not specified in the request. 
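(Note, not part of the diff: the device parametrization introduced across these test files follows one pattern, shown here as a minimal consolidated sketch using the same current_platform helpers the hunks above rely on. Build a CUDA device list and a CPU device list, pick one from the detected platform, and switch the distributed backend the same way.)

import torch
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CPU_DEVICES = ["cpu"]
# Tests parametrize over DEVICES so the same cases run on GPU and CPU builds.
DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES
# The distributed backend follows the same split: NCCL on CUDA-like
# platforms, Gloo on CPU-only builds.
BACKEND = "gloo" if current_platform.is_cpu() else "nccl"
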
@@ -615,7 +619,7 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, device) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up model.supported_lora_modules = ["gate_up_proj"] diff --git a/tests/lora/test_mixtral.py b/tests/lora/test_mixtral.py index dddc299da446..31237acd549e 100644 --- a/tests/lora/test_mixtral.py +++ b/tests/lora/test_mixtral.py @@ -5,6 +5,7 @@ import vllm from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1" @@ -31,7 +32,8 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int, @pytest.mark.parametrize("tp_size", [4]) def test_mixtral_lora(mixtral_lora_files, tp_size): """Original test, the LoRA model has the common target modules, not all""" - if torch.cuda.device_count() < tp_size: + if torch.cuda.device_count( + ) < tp_size and tp_size > 1 and current_platform.is_cuda_alike(): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") prompts = [ diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97..af31748f8cbd 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -7,12 +7,8 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -110,7 +106,10 @@ MAX_RANKS = [32] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] + +CUDA_DEVICES = ["cuda:0"] +CPU_DEVICES = ["cpu"] +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES def assert_close(a, b): @@ -130,7 +129,7 @@ def assert_close(a, b): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -220,7 +219,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -294,7 +293,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_expand_nslices( batches: int, num_loras: int, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 3b20033271d2..f280c758c6fd 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,13 +6,30 @@ import pytest import torch -# Enable custom op register -import vllm.lora.ops.bgmv_expand -import vllm.lora.ops.bgmv_expand_slice -import 
vllm.lora.ops.bgmv_shrink -import vllm.lora.ops.sgmv_expand -import vllm.lora.ops.sgmv_expand_slice -import vllm.lora.ops.sgmv_shrink # noqa: F401 +from vllm.triton_utils import HAS_TRITON + +# Enable custom op register if we're using custom ops +if HAS_TRITON: + import vllm.lora.ops.triton.bgmv_expand + import vllm.lora.ops.triton.bgmv_expand_slice + import vllm.lora.ops.triton.bgmv_shrink + import vllm.lora.ops.triton.sgmv_expand + import vllm.lora.ops.triton.sgmv_expand_slice + import vllm.lora.ops.triton.sgmv_shrink # noqa: F401 + + # Unlike test_punica_sizes.py, we directly utilize custom op for + # testing, which verifies the correct registration of these ops. + bgmv_expand = torch.ops.vllm.bgmv_expand + bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + bgmv_shrink = torch.ops.vllm.bgmv_shrink + sgmv_expand = torch.ops.vllm.sgmv_expand + sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice + sgmv_shrink = torch.ops.vllm.sgmv_shrink +else: + from vllm.lora.ops.default.lora_ops import ( # type: ignore + bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -26,7 +43,10 @@ MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] SCALES = [0.5] SEED = [0] -CUDA_DEVICES = [f"cuda:{0}"] + +CUDA_DEVICES = ["cuda:0"] +CPU_DEVICES = ["cpu"] +DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES def assert_close(a, b): @@ -38,16 +58,6 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -# Unlike test_punica_sizes.py, we directly utilize custom op for -# testing, which verifies the correct registration of these ops. -bgmv_expand = torch.ops.vllm.bgmv_expand -bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice -bgmv_shrink = torch.ops.vllm.bgmv_shrink -sgmv_expand = torch.ops.vllm.sgmv_expand -sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice -sgmv_shrink = torch.ops.vllm.sgmv_shrink - - @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -56,7 +66,7 @@ def assert_close(a, b): @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( batches: int, num_loras: int, @@ -146,7 +156,7 @@ def test_punica_sgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_bgmv( batches: int, num_loras: int, @@ -222,7 +232,7 @@ def test_punica_bgmv( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("device", DEVICES) def test_punica_expand_nslices( batches: int, num_loras: int, diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 5432fa4ad0d3..c2590594a277 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -72,7 +72,8 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, tp_size): - if num_gpus_available < tp_size: + if 
num_gpus_available < tp_size and \ + tp_size > 1 and current_platform.is_cuda_alike(): pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM( diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index e69de29bb2d1..7a13eabeb607 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -0,0 +1,18 @@ +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink +else: + from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) + +__all__ = [ + "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand", + "sgmv_expand_slice", "sgmv_shrink" +] diff --git a/vllm/lora/ops/default/__init__.py b/vllm/lora/ops/default/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/default/lora_ops.py new file mode 100644 index 000000000000..5f5aafd51615 --- /dev/null +++ b/vllm/lora/ops/default/lora_ops.py @@ -0,0 +1,113 @@ +import torch + + +def sgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, + add_inputs) + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + if add_inputs: + output_tensor[:, :outputs.shape[1]] += outputs[:limit, :] + else: + output_tensor[:, :outputs.shape[1]] = outputs[:limit, :] + + +def sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, + scaling) + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + inputs = inputs.to(dtype=output_tensor.dtype) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + output_tensor[:, 
:outputs.shape[1]] = scaling * outputs[:] + + +def sgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False): + exploded_indices = torch.repeat_interleave(lora_indices_tensor, + seq_len_tensor) + + bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices, + slice_offset, slice_size, add_inputs) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + selected_loras = lora_b_weights[lora_indices_tensor].to( + dtype=output_tensor.dtype) + inputs = inputs.to(dtype=output_tensor.dtype) + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(dim=1) + outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras) + + if add_inputs: + output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:] + else: + output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:] diff --git a/vllm/lora/ops/triton/__init__.py b/vllm/lora/ops/triton/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/triton/bgmv_expand.py similarity index 100% rename from vllm/lora/ops/bgmv_expand.py rename to vllm/lora/ops/triton/bgmv_expand.py diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/triton/bgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/bgmv_expand_slice.py rename to vllm/lora/ops/triton/bgmv_expand_slice.py diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/triton/bgmv_shrink.py similarity index 100% rename from vllm/lora/ops/bgmv_shrink.py rename to vllm/lora/ops/triton/bgmv_shrink.py diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/triton/sgmv_expand.py similarity index 100% rename from vllm/lora/ops/sgmv_expand.py rename to vllm/lora/ops/triton/sgmv_expand.py diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/triton/sgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/sgmv_expand_slice.py rename to vllm/lora/ops/triton/sgmv_expand_slice.py diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/triton/sgmv_shrink.py similarity index 100% rename from vllm/lora/ops/sgmv_shrink.py rename to vllm/lora/ops/triton/sgmv_shrink.py diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/triton/utils.py similarity index 100% rename from vllm/lora/ops/utils.py rename to vllm/lora/ops/triton/utils.py diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py new file mode 100644 index 000000000000..4b13f221fd62 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -0,0 +1,14 @@ +from typing import final + +from .punica_gpu import PunicaWrapperGPU + + +@final +class PunicaWrapperCPU(PunicaWrapperGPU): + """ + PunicaWrapperCPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + It uses the punica ops in the same manner as the GPU implementation. 
+ """ + pass diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index b2af29de129c..5b48c002e5f2 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -5,24 +5,16 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Callable, Optional, Tuple, Union, final +from typing import Callable, Optional, Tuple, Union import torch -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.lora.ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase -@final class PunicaWrapperGPU(PunicaWrapperBase): """ PunicaWrapperGPU is designed to manage and provide metadata for the punica diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index df6c1bdc7dd7..b7758688e52b 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -10,5 +10,10 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU print_info_once("Using PunicaWrapperGPU.") return PunicaWrapperGPU(*args, **kwargs) + elif current_platform.is_cpu(): + # Lazy import to avoid ImportError + from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU + print_info_once("Using PunicaWrapperCPU.") + return PunicaWrapperCPU(*args, **kwargs) else: raise NotImplementedError diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 420aaf8a1b4c..db8a03abf0cc 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,8 +2,8 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, - Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, + TypeVar, Union) import torch from torch import nn @@ -12,10 +12,14 @@ from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, @@ -49,6 +53,8 @@ class ModelInputForCPU(ModelRunnerInputBase): virtual_engine: Optional[int] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -57,6 +63,8 @@ def as_broadcastable_tensor_dict( "input_positions": 
self.input_positions, "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -145,7 +153,11 @@ def __init__(self, or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.enable_lora = self.runner.lora_config is not None self.input_data = ModelInputForCPUBuilder.ModelInputData( self.runner.model_config.uses_mrope) self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( @@ -184,15 +196,28 @@ def build(self) -> ModelInputForCPU: attn_metadata = self.att_metadata_builder.build( input_data.seq_lens, input_data.query_lens, -1, -1) - return self.model_input_cls( - input_tokens=input_tokens, - input_positions=input_positions, - token_type_ids=token_type_ids, - seq_lens=input_data.seq_lens, - query_lens=input_data.query_lens, - attn_metadata=attn_metadata, - multi_modal_kwargs=multi_modal_kwargs, - ) + is_prompt = (self.seq_group_metadata_list[0].is_prompt + if self.seq_group_metadata_list else None) + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(seq.lora_request + for seq in self.seq_group_metadata_list + if seq.lora_request is not None) + + lora_mapping = self._prepare_lora_input( + self.seq_group_metadata_list, is_prompt) + + return self.model_input_cls(input_tokens=input_tokens, + input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + lora_mapping=lora_mapping, + lora_requests=lora_requests) def _build_input_data(self): for seq_group_metadata in self.seq_group_metadata_list: @@ -383,6 +408,24 @@ def _compute_multi_modal_input(self, self.input_data.multi_modal_placeholder_maps[modality].extend( placeholder_map) + def _prepare_lora_input( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + is_prefill: bool) -> LoRAMapping: + index_mapping = [] + prompt_mapping = [] + for seq in seq_group_metadata_list: + lora_id = seq.lora_int_id + query_len = seq.token_chunk_size + + index_mapping += [lora_id] * query_len + prompt_mapping += [lora_id] * ( + query_len if seq.sampling_params + and seq.sampling_params.prompt_logprobs is not None else 1) + + return LoRAMapping(index_mapping=tuple(index_mapping), + prompt_mapping=tuple(prompt_mapping), + is_prefill=is_prefill) + class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): """ @@ -433,10 +476,41 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None def load_model(self) -> None: self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + def _prepare_model_input_tensors( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -461,6 +535,37 @@ def sampler(self): def vocab_size(self) -> int: return self.model_config.get_vocab_size() + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( @@ -517,6 +622,12 @@ def execute_model( raise ValueError( "CPU worker does not support multi-step execution.") + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + model_executable = self.model multimodal_kwargs = {} diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 4fad1a3f4cae..cb78d3282824 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Set, Tuple, Type import torch import torch.distributed @@ -11,14 +11,14 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoraNotSupportedWorkerBase, WorkerBase, +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) logger = init_logger(__name__) @@ -111,7 +111,7 @@ 
def get_cache_block_size( return dtype_size * total -class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): +class CPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a CPU socket. Each worker is associated with a single CPU socket. The worker is @@ -266,6 +266,18 @@ def initialize_cache(self, num_gpu_blocks: int, # Initialize the cache. self._init_cache_engine() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: """Raise errors if the num_cpu_blocks is invalid. """ From 3545ac167d1598bebeee0de13f4b9e67c19f4027 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 11 Dec 2024 14:58:19 +0000 Subject: [PATCH 02/23] Readded final Signed-off-by: Akshat Tripathi --- tests/lora/test_layers.py | 2 +- vllm/lora/punica_wrapper/punica_gpu.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 2f5dddf8c946..f8702e51fdb0 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -208,7 +208,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: return type(punica_wrapper) is PunicaWrapperGPU elif current_platform.is_cpu(): from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU - + return type(punica_wrapper) is PunicaWrapperCPU else: return False diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 5b48c002e5f2..59b2ab329998 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, final import torch @@ -14,7 +14,7 @@ from .punica_base import PunicaWrapperBase - +@final class PunicaWrapperGPU(PunicaWrapperBase): """ PunicaWrapperGPU is designed to manage and provide metadata for the punica From 2b3c650a009f2d60428d62ba9353d47f22e32cb6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 11 Dec 2024 15:27:41 +0000 Subject: [PATCH 03/23] Decoupled PunicaWrapperCPU from PunicaWrapperGPU Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_cpu.py | 338 ++++++++++++++++++++++++- 1 file changed, 334 insertions(+), 4 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 4b13f221fd62..9906f3a88d50 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,14 +1,344 @@ from typing import final -from .punica_gpu import PunicaWrapperGPU +import torch + +from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, sgmv_shrink) + +from .punica_base import PunicaWrapperBase @final -class PunicaWrapperCPU(PunicaWrapperGPU): +class PunicaWrapperCPU(PunicaWrapperBase): """ PunicaWrapperCPU is designed to manage and provide metadata for the punica kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. - It uses the punica ops in the same manner as the GPU implementation. 
""" - pass + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + def _shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def _shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def _expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def _expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def _expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def _expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) + shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. 
+ + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_input) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. 
+ """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_input=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + y = y.view_as(y_org) From cc0bd6ccf7b194fe1a84450d779cd1ab2b01b1e3 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 11 Dec 2024 15:29:16 +0000 Subject: [PATCH 04/23] Fixed typing Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 9906f3a88d50..854e26ebe6a3 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,4 +1,4 @@ -from typing import final +from typing import Callable, Optional, Tuple, Union, final import torch From dc9091a88936189f044cb3fc23f5802c2041d637 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 12 Dec 2024 10:59:45 +0000 Subject: [PATCH 05/23] Renamed lora op directories Signed-off-by: Akshat Tripathi --- vllm/lora/ops/__init__.py | 25 +++++++++++-------- .../ops/{default => torch_ops}/__init__.py | 0 .../ops/{default => torch_ops}/lora_ops.py | 0 .../ops/{triton => triton_ops}/__init__.py | 0 .../ops/{triton => triton_ops}/bgmv_expand.py | 0 .../bgmv_expand_slice.py | 0 .../ops/{triton => triton_ops}/bgmv_shrink.py | 0 .../ops/{triton => triton_ops}/sgmv_expand.py | 0 .../sgmv_expand_slice.py | 0 .../ops/{triton => triton_ops}/sgmv_shrink.py | 0 vllm/lora/ops/{triton => triton_ops}/utils.py | 0 vllm/lora/punica_wrapper/punica_cpu.py | 2 +- vllm/lora/punica_wrapper/punica_gpu.py | 1 + 13 files changed, 16 insertions(+), 12 deletions(-) rename vllm/lora/ops/{default => torch_ops}/__init__.py (100%) rename vllm/lora/ops/{default => torch_ops}/lora_ops.py (100%) rename vllm/lora/ops/{triton => 
triton_ops}/__init__.py (100%) rename vllm/lora/ops/{triton => triton_ops}/bgmv_expand.py (100%) rename vllm/lora/ops/{triton => triton_ops}/bgmv_expand_slice.py (100%) rename vllm/lora/ops/{triton => triton_ops}/bgmv_shrink.py (100%) rename vllm/lora/ops/{triton => triton_ops}/sgmv_expand.py (100%) rename vllm/lora/ops/{triton => triton_ops}/sgmv_expand_slice.py (100%) rename vllm/lora/ops/{triton => triton_ops}/sgmv_shrink.py (100%) rename vllm/lora/ops/{triton => triton_ops}/utils.py (100%) diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index 7a13eabeb607..1b714246e6e9 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -1,16 +1,19 @@ +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON -if HAS_TRITON: - from vllm.lora.ops.triton.bgmv_expand import bgmv_expand - from vllm.lora.ops.triton.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.triton.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.triton.sgmv_expand import sgmv_expand - from vllm.lora.ops.triton.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.triton.sgmv_shrink import sgmv_shrink -else: - from vllm.lora.ops.default.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +if HAS_TRITON and current_platform.is_cuda_alike(): + from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink +elif current_platform.is_cpu(): + from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand, + bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, + sgmv_shrink) __all__ = [ "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand", diff --git a/vllm/lora/ops/default/__init__.py b/vllm/lora/ops/torch_ops/__init__.py similarity index 100% rename from vllm/lora/ops/default/__init__.py rename to vllm/lora/ops/torch_ops/__init__.py diff --git a/vllm/lora/ops/default/lora_ops.py b/vllm/lora/ops/torch_ops/lora_ops.py similarity index 100% rename from vllm/lora/ops/default/lora_ops.py rename to vllm/lora/ops/torch_ops/lora_ops.py diff --git a/vllm/lora/ops/triton/__init__.py b/vllm/lora/ops/triton_ops/__init__.py similarity index 100% rename from vllm/lora/ops/triton/__init__.py rename to vllm/lora/ops/triton_ops/__init__.py diff --git a/vllm/lora/ops/triton/bgmv_expand.py b/vllm/lora/ops/triton_ops/bgmv_expand.py similarity index 100% rename from vllm/lora/ops/triton/bgmv_expand.py rename to vllm/lora/ops/triton_ops/bgmv_expand.py diff --git a/vllm/lora/ops/triton/bgmv_expand_slice.py b/vllm/lora/ops/triton_ops/bgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/triton/bgmv_expand_slice.py rename to vllm/lora/ops/triton_ops/bgmv_expand_slice.py diff --git a/vllm/lora/ops/triton/bgmv_shrink.py b/vllm/lora/ops/triton_ops/bgmv_shrink.py similarity index 100% rename from vllm/lora/ops/triton/bgmv_shrink.py rename to vllm/lora/ops/triton_ops/bgmv_shrink.py diff --git a/vllm/lora/ops/triton/sgmv_expand.py b/vllm/lora/ops/triton_ops/sgmv_expand.py similarity index 100% rename from vllm/lora/ops/triton/sgmv_expand.py rename to vllm/lora/ops/triton_ops/sgmv_expand.py diff --git a/vllm/lora/ops/triton/sgmv_expand_slice.py 
b/vllm/lora/ops/triton_ops/sgmv_expand_slice.py similarity index 100% rename from vllm/lora/ops/triton/sgmv_expand_slice.py rename to vllm/lora/ops/triton_ops/sgmv_expand_slice.py diff --git a/vllm/lora/ops/triton/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py similarity index 100% rename from vllm/lora/ops/triton/sgmv_shrink.py rename to vllm/lora/ops/triton_ops/sgmv_shrink.py diff --git a/vllm/lora/ops/triton/utils.py b/vllm/lora/ops/triton_ops/utils.py similarity index 100% rename from vllm/lora/ops/triton/utils.py rename to vllm/lora/ops/triton_ops/utils.py diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 854e26ebe6a3..02dd8693f8f5 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -15,7 +15,7 @@ class PunicaWrapperCPU(PunicaWrapperBase): kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the pytorch punica ops. """ - + def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 59b2ab329998..fe7c7c538918 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -14,6 +14,7 @@ from .punica_base import PunicaWrapperBase + @final class PunicaWrapperGPU(PunicaWrapperBase): """ From 410b746ce5fc9eb0c5c3c878a01c58a2d69b4974 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Thu, 12 Dec 2024 11:02:08 +0000 Subject: [PATCH 06/23] Fixed tests Signed-off-by: Akshat Tripathi --- tests/lora/test_punica_variation.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index f280c758c6fd..3e91e021bde6 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -10,12 +10,12 @@ # Enable custom op register if we're using custom ops if HAS_TRITON: - import vllm.lora.ops.triton.bgmv_expand - import vllm.lora.ops.triton.bgmv_expand_slice - import vllm.lora.ops.triton.bgmv_shrink - import vllm.lora.ops.triton.sgmv_expand - import vllm.lora.ops.triton.sgmv_expand_slice - import vllm.lora.ops.triton.sgmv_shrink # noqa: F401 + import vllm.lora.ops.triton_ops.bgmv_expand + import vllm.lora.ops.triton_ops.bgmv_expand_slice + import vllm.lora.ops.triton_ops.bgmv_shrink + import vllm.lora.ops.triton_ops.sgmv_expand + import vllm.lora.ops.triton_ops.sgmv_expand_slice + import vllm.lora.ops.triton_ops.sgmv_shrink # noqa: F401 # Unlike test_punica_sizes.py, we directly utilize custom op for # testing, which verifies the correct registration of these ops. 
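(Note, not part of the diff: the pure-PyTorch fallback ops imported above from vllm/lora/ops/torch_ops/lora_ops.py can also be exercised directly. A minimal sketch with hypothetical shapes follows; it runs the bgmv shrink-then-expand round trip that the CPU punica wrapper performs for a batch of tokens, using the signatures added earlier in this series.)

import torch
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand, bgmv_shrink

num_tokens, hidden, rank, num_loras = 8, 64, 16, 4   # hypothetical sizes
x = torch.randn(num_tokens, hidden)                   # input activations
lora_a = torch.randn(num_loras, 1, rank, hidden)      # stacked lora_a weights
lora_b = torch.randn(num_loras, 1, hidden, rank)      # stacked lora_b weights
token_lora_indices = torch.randint(0, num_loras, (num_tokens,))

# Shrink: route each token through the lora_a matrix of its assigned adapter.
buffer = torch.zeros(num_tokens, rank, dtype=torch.float32)
bgmv_shrink(x, lora_a, buffer, token_lora_indices, scaling=0.5)

# Expand: project back to the hidden size and accumulate into the output.
y = torch.zeros(num_tokens, hidden)
bgmv_expand(buffer, lora_b, y, token_lora_indices, add_inputs=True)
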
@@ -26,7 +26,7 @@ sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice sgmv_shrink = torch.ops.vllm.sgmv_shrink else: - from vllm.lora.ops.default.lora_ops import ( # type: ignore + from vllm.lora.ops.torch_ops.lora_ops import ( # type: ignore bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) From 5dfcf627b3fe4894026227605f26bc6fa7a11cf6 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 13 Dec 2024 10:40:19 +0000 Subject: [PATCH 07/23] Removed one-to-one correspondence between triton and torch ops Signed-off-by: Akshat Tripathi --- tests/lora/test_punica_sizes.py | 17 +++++++++++++++-- vllm/lora/ops/__init__.py | 21 --------------------- vllm/lora/punica_wrapper/punica_cpu.py | 5 +++-- vllm/lora/punica_wrapper/punica_gpu.py | 11 +++++++++-- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index af31748f8cbd..f1391538e747 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -7,9 +7,22 @@ import pytest import torch -from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, sgmv_shrink) from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON + +# Enable custom op register if we're using custom ops +if HAS_TRITON: + from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink + +elif current_platform.is_cpu(): + from vllm.lora.ops.torch_ops.lora_ops import ( # type: ignore + bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from .utils import (generate_data, generate_data_for_expand_nslices, ref_torch_groupgemm) diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py index 1b714246e6e9..e69de29bb2d1 100644 --- a/vllm/lora/ops/__init__.py +++ b/vllm/lora/ops/__init__.py @@ -1,21 +0,0 @@ -from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON and current_platform.is_cuda_alike(): - from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink -elif current_platform.is_cpu(): - from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand, - bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, - sgmv_shrink) - -__all__ = [ - "bgmv_expand", "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand", - "sgmv_expand_slice", "sgmv_shrink" -] diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 02dd8693f8f5..8e00912beeae 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -2,8 +2,9 @@ import torch -from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, 
sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index fe7c7c538918..3909661713df 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -9,8 +9,15 @@ import torch -from vllm.lora.ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, - sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink from .punica_base import PunicaWrapperBase From 651dc049d3f678762581885c312b7a9d3a081905 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 13 Dec 2024 10:50:18 +0000 Subject: [PATCH 08/23] Fixed mypy error Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_cpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 8e00912beeae..58341241c744 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -2,9 +2,9 @@ import torch -from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase From d1026bbc6811a7a42a7fdfc00eaf42023a28da69 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 13 Dec 2024 10:57:46 +0000 Subject: [PATCH 09/23] Removed redundant optionals Signed-off-by: Akshat Tripathi --- vllm/lora/punica_wrapper/punica_cpu.py | 12 ++++++------ vllm/lora/punica_wrapper/punica_gpu.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 58341241c744..d6f4fa36561d 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -81,8 +81,8 @@ def _expand_slice_prefill( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool, ): #No LoRA request, so return directly @@ -103,8 +103,8 @@ def _expand_slice_decode( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool, ): bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, @@ -115,8 +115,8 @@ def _apply_expand( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool = True, ): """ diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 3909661713df..cc962d18733e 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -94,8 +94,8 @@ def _expand_slice_prefill( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: 
Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool, ): #No LoRA request, so return directly @@ -116,8 +116,8 @@ def _expand_slice_decode( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool, ): bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, @@ -128,8 +128,8 @@ def _apply_expand( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], + y_offset: int, + y_slice_size: int, add_input: bool = True, ): """ From 41c518ff8b81ebe2c70667d30e7289762c62144b Mon Sep 17 00:00:00 2001 From: Oleg Mosalov Date: Tue, 17 Dec 2024 09:58:53 +0100 Subject: [PATCH 10/23] Renamed add_input to add_inputs in punica_cpu.py. Signed-off-by: Oleg Mosalov --- vllm/lora/punica_wrapper/punica_cpu.py | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index d6f4fa36561d..4235e7bf4485 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -54,7 +54,7 @@ def _expand_prefill( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - add_input: bool, + add_inputs: bool, ): #No LoRA request, so return directly if self.no_lora: @@ -64,7 +64,7 @@ def _expand_prefill( w_t_all, y, *self.prefill_metadata, - add_input, + add_inputs, ) def _expand_decode( @@ -72,9 +72,9 @@ def _expand_decode( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, - add_input: bool, + add_inputs: bool, ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) def _expand_slice_prefill( self, @@ -83,7 +83,7 @@ def _expand_slice_prefill( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_input: bool, + add_inputs: bool, ): #No LoRA request, so return directly if self.no_lora: @@ -95,7 +95,7 @@ def _expand_slice_prefill( *self.prefill_metadata, y_offset, y_slice_size, - add_input, + add_inputs, ) def _expand_slice_decode( @@ -105,10 +105,10 @@ def _expand_slice_decode( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_input: bool, + add_inputs: bool, ): bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, - y_slice_size, add_input) + y_slice_size, add_inputs) def _apply_expand( self, @@ -117,7 +117,7 @@ def _apply_expand( w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, - add_input: bool = True, + add_inputs: bool = True, ): """ Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` @@ -128,7 +128,7 @@ def _apply_expand( expand_slice_fun: Callable = (self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float): @@ -181,7 +181,7 @@ def add_expand(self, lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], output_slices: Tuple[int, ...], offset_start: int = 0, - add_input=True, + add_inputs=True, **kwargs) -> None: """ Performs GEMM and bias addition for multiple slices of lora_b. 
@@ -200,7 +200,7 @@ def add_expand(self, lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): bias's weight output_slices (Tuple[int, ...]): Every slice's size - add_input (bool): Defaults to True. + add_inputs (bool): Defaults to True. """ y_org = y y = y.view(-1, y.shape[-1]) @@ -215,7 +215,7 @@ def add_expand(self, lora_b_stacked[slice_idx], offset_left, output_slices[slice_idx], - add_input=add_input, + add_inputs=add_inputs, ) offset_left += output_slices[slice_idx] y = y.view_as(y_org) @@ -224,7 +224,7 @@ def add_lora_embedding(self, y: torch.Tensor, x: torch.Tensor, lora_b_stacked: torch.Tensor, - add_input: bool = True, + add_inputs: bool = True, **kwargs) -> None: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. @@ -236,13 +236,13 @@ def add_lora_embedding(self, y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensor. lora_b_stacked (torch.Tensor): lora_b's weights. - add_input (bool): Default to True. + add_inputs (bool): Default to True. """ # Embedding layer only need expand op expand_fun: Callable = (self._expand_prefill if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_input) + expand_fun(y, x, lora_b_stacked, add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -298,7 +298,7 @@ def add_lora_linear(self, lora_b_stacked, None, output_slices, - add_input=True, + add_inputs=True, **kwargs) def add_lora_logits(self, From d212fc9d0b339945dc75dd02688c3e347bc0e72f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 24 Dec 2024 02:30:12 +0000 Subject: [PATCH 11/23] Optimize kernel test Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 8 +- tests/lora/test_lora_manager.py | 7 +- tests/lora/test_punica_sizes.py | 392 ------------------------ tests/lora/test_punica_variation.py | 316 ------------------- vllm/lora/ops/torch_ops/__init__.py | 13 + vllm/lora/ops/triton_ops/__init__.py | 15 + vllm/lora/ops/triton_ops/sgmv_shrink.py | 1 - vllm/lora/punica_wrapper/punica_cpu.py | 19 +- vllm/lora/punica_wrapper/punica_gpu.py | 12 +- 9 files changed, 48 insertions(+), 735 deletions(-) delete mode 100644 tests/lora/test_punica_sizes.py delete mode 100644 tests/lora/test_punica_variation.py diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index f8702e51fdb0..08a589d7ee29 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -53,13 +53,9 @@ not (current_platform.is_cuda_alike() or current_platform.is_cpu()), reason="Backend not supported") -CUDA_DEVICES = [ +DEVICES = ([ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -CPU_DEVICES = ["cpu"] - -DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES +] if current_platform.is_cuda_alike() else ["cpu"]) #For GPU, we will launch different triton kernels between the prefill and decode # stages, so we need to verify this. 
prefill stage(True) or decode stage(False) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 927d8392452b..d3422810ae8a 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -28,12 +28,9 @@ EMBEDDING_PADDING_MODULES = ["lm_head"] -CUDA_DEVICES = [ +DEVICES = ([ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] -CPU_DEVICES = ["cpu"] - -DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES +] if current_platform.is_cuda_alike() else ["cpu"]) def test_peft_helper(sql_lora_files): diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py deleted file mode 100644 index f1391538e747..000000000000 --- a/tests/lora/test_punica_sizes.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -This script is mainly used to tests various hidden_sizes. We have collected the -hidden_sizes included in the LoRA models currently supported by vLLM. It tests -whether the corresponding Triton kernel can run normally when tensor parallelism -is set to [1, 2, 4, 8, 16, 32, 64]. -""" -import pytest -import torch - -from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON - -# Enable custom op register if we're using custom ops -if HAS_TRITON: - from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink - -elif current_platform.is_cpu(): - from vllm.lora.ops.torch_ops.lora_ops import ( # type: ignore - bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) - -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) - -HIDDEN_SIZES = [ - 128, - 256, - 512, - 896, - 1024, - 1152, - 1216, - 1280, - 1536, - 1664, - 2048, - 2240, - 2304, - 2368, - 2432, - 2560, - 2752, - 3072, - 3328, - 3456, - 3584, - 3712, - 4096, - 4480, - 4608, - 4736, - 4864, - 5120, - 5504, - 5632, - 5888, - 6144, - 6400, - 6848, - 6912, - 7168, - 7424, - 8192, - 8960, - 9216, - 9472, - 10240, - 11008, - 11264, - 13824, - 14336, - 14784, - 14848, - 15360, - 18944, - 22016, - 22528, - 24576, - 27392, - 27648, - 29568, - 29696, - 32000, - 32256, - 32512, - 32768, - 33024, - 36864, - 43264, - 49152, - 49408, - 60544, - 60672, - 64000, - 64256, - 102400, - 102656, - 128000, - 128256, -] -#The size of TP -divisibility = [1, 2, 8, 16, 64] - -all_hidden_size = [] -for div in divisibility: - for hidden_size in HIDDEN_SIZES: - all_hidden_size.append(hidden_size // div) - -HIDDEN_SIZES = list(set(all_hidden_size)) - -BATCHES = [4] -NUM_LORA = [4] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [32] -SCALES = [0.5] -SEED = [0] - -CUDA_DEVICES = ["cuda:0"] -CPU_DEVICES = ["cpu"] -DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES - - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) 
-@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - if op_type == "shrink": - sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - else: - sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - else: - bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) -@pytest.mark.parametrize("seed", SEED) 
-@pytest.mark.parametrize("device", DEVICES) -def test_punica_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 if op_type == "sgmv" else 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + hidden_size], - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py deleted file mode 100644 index 3e91e021bde6..000000000000 --- a/tests/lora/test_punica_variation.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -This script is mainly used to test whether trtion kernels can run normally -under different conditions, including various batches, numbers of LoRA , and -maximum ranks. -""" -import pytest -import torch - -from vllm.triton_utils import HAS_TRITON - -# Enable custom op register if we're using custom ops -if HAS_TRITON: - import vllm.lora.ops.triton_ops.bgmv_expand - import vllm.lora.ops.triton_ops.bgmv_expand_slice - import vllm.lora.ops.triton_ops.bgmv_shrink - import vllm.lora.ops.triton_ops.sgmv_expand - import vllm.lora.ops.triton_ops.sgmv_expand_slice - import vllm.lora.ops.triton_ops.sgmv_shrink # noqa: F401 - - # Unlike test_punica_sizes.py, we directly utilize custom op for - # testing, which verifies the correct registration of these ops. 
- bgmv_expand = torch.ops.vllm.bgmv_expand - bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice - bgmv_shrink = torch.ops.vllm.bgmv_shrink - sgmv_expand = torch.ops.vllm.sgmv_expand - sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice - sgmv_shrink = torch.ops.vllm.sgmv_shrink -else: - from vllm.lora.ops.torch_ops.lora_ops import ( # type: ignore - bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) - -from vllm.platforms import current_platform - -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) - -HIDDEN_SIZES = [4097] - -BATCHES = [1, 4, 16, 32] -NUM_LORA = [1, 8, 32, 128] -DTYPES = [torch.float16, torch.bfloat16] -MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] -SCALES = [0.5] -SEED = [0] - -CUDA_DEVICES = ["cuda:0"] -CPU_DEVICES = ["cpu"] -DEVICES = CUDA_DEVICES if current_platform.is_cuda_alike() else CPU_DEVICES - - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_sgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - if op_type == "shrink": - sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - else: - sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("scaling", SCALES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_bgmv( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - 
scaling: float, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 1 - ( - inputs_tensor, - lora_weights, - our_out_tensor, - ref_out_tensor, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - op_type, - device, - ) - if op_type == "shrink": - bgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - scaling, - ) - else: - - bgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - indices, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) - assert_close(our_out_tensor, ref_out_tensor) - - -@pytest.mark.parametrize("batches", BATCHES) -@pytest.mark.parametrize("num_loras", NUM_LORA) -@pytest.mark.parametrize("rank", MAX_RANKS) -@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("nslices", [2, 3]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) -@pytest.mark.parametrize("seed", SEED) -@pytest.mark.parametrize("device", DEVICES) -def test_punica_expand_nslices( - batches: int, - num_loras: int, - rank: int, - hidden_size: int, - nslices: int, - dtype: torch.dtype, - op_type: str, - seed: int, - device: str, -): - torch.set_default_device(device) - current_platform.seed_everything(seed) - - seq_length = 128 if op_type == "sgmv" else 1 - ( - inputs_tensor, - lora_weights_lst, - our_outputs, - ref_outputs, - b_seq_start_loc, - lora_indices_tensor, - seq_len_tensor, - indices, - ) = generate_data_for_expand_nslices( - batches, - hidden_size, - num_loras, - rank, - seq_length, - dtype, - nslices, - device, - ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + hidden_size], - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", - ) - - slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) diff --git a/vllm/lora/ops/torch_ops/__init__.py b/vllm/lora/ops/torch_ops/__init__.py index e69de29bb2d1..9c9159b95f30 100644 --- a/vllm/lora/ops/torch_ops/__init__.py +++ b/vllm/lora/ops/torch_ops/__init__.py @@ -0,0 +1,13 @@ +from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401 +from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink, + sgmv_expand, sgmv_expand_slice, + sgmv_shrink) + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git 
a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index e69de29bb2d1..70d08c682713 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -0,0 +1,15 @@ +from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand +from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice +from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink +from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand +from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice +from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401 + +__all__ = [ + "bgmv_expand", + "bgmv_expand_slice", + "bgmv_shrink", + "sgmv_expand", + "sgmv_expand_slice", + "sgmv_shrink", +] diff --git a/vllm/lora/ops/triton_ops/sgmv_shrink.py b/vllm/lora/ops/triton_ops/sgmv_shrink.py index 37d1dc84eebc..afe04193018d 100644 --- a/vllm/lora/ops/triton_ops/sgmv_shrink.py +++ b/vllm/lora/ops/triton_ops/sgmv_shrink.py @@ -165,7 +165,6 @@ def _sgmv_shrink( SPLIT_K, batches, ) - _sgmv_shrink_kernel[grid]( inputs, lora_a_weights, diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 4235e7bf4485..b9ae3e07492c 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,15 +1,16 @@ -from typing import Callable, Optional, Tuple, Union, final +from typing import Callable, Optional, Tuple, Union import torch -from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, - bgmv_shrink, sgmv_expand, - sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) from .punica_base import PunicaWrapperBase -@final +# The platforms that are compatible with the PyTorch-native implementation can +# inherit this class class PunicaWrapperCPU(PunicaWrapperBase): """ PunicaWrapperCPU is designed to manage and provide metadata for the punica @@ -286,8 +287,8 @@ def add_lora_linear(self, if buffer is None: r = lora_b_stacked[0].size(-1) - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 + # We set the buffer to be float32 by default, consistent with the + # triton op buffer = tuple( torch.zeros( (x.size(0), r), dtype=torch.float32, device=x.device) @@ -330,8 +331,8 @@ def add_lora_logits(self, x = x.view(-1, x.shape[-1]) r = lora_b_stacked.size(-1) if buffer is None: - # We set the buffer to be float32 by default ,refer to: - # https://github.com/triton-lang/triton/issues/1387 + # We set the buffer to be float32 by default, consistent with the + # triton op buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 785976c0cb8e..f69a0bca7f09 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -12,12 +12,12 @@ from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand - from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice - from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink - from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice - from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink + from vllm.lora.ops.triton_ops import bgmv_expand + from 
vllm.lora.ops.triton_ops import bgmv_expand_slice + from vllm.lora.ops.triton_ops import bgmv_shrink + from vllm.lora.ops.triton_ops import sgmv_expand + from vllm.lora.ops.triton_ops import sgmv_expand_slice + from vllm.lora.ops.triton_ops import sgmv_shrink from .punica_base import PunicaWrapperBase From 3a684b8a2a292386cb60a72831755a7f78e79d1f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 30 Dec 2024 03:03:55 +0000 Subject: [PATCH 12/23] Modify test name Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_sizes.py | 416 ++++++++++++++++++++++++ tests/lora/test_punica_ops_variation.py | 344 ++++++++++++++++++++ 2 files changed, 760 insertions(+) create mode 100644 tests/lora/test_punica_ops_sizes.py create mode 100644 tests/lora/test_punica_ops_variation.py diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py new file mode 100644 index 000000000000..b2dbd5645c73 --- /dev/null +++ b/tests/lora/test_punica_ops_sizes.py @@ -0,0 +1,416 @@ +""" +This script is mainly used to tests various hidden_sizes. We have collected the +hidden_sizes included in the LoRA models currently supported by vLLM. It tests +whether the corresponding Triton kernel can run normally when tensor parallelism +is set to [1, 2, 4, 8, 16, 32, 64]. +""" +import pytest +import torch + +# Enable custom op register +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.platforms import current_platform + +from .utils import generate_data, generate_data_for_expand_nslices + +HIDDEN_SIZES = [ + 128, + 256, + 512, + 896, + 1024, + 1152, + 1216, + 1280, + 1536, + 1664, + 2048, + 2240, + 2304, + 2368, + 2432, + 2560, + 2752, + 3072, + 3328, + 3456, + 3584, + 3712, + 4096, + 4480, + 4608, + 4736, + 4864, + 5120, + 5504, + 5632, + 5888, + 6144, + 6400, + 6848, + 6912, + 7168, + 7424, + 8192, + 8960, + 9216, + 9472, + 10240, + 11008, + 11264, + 13824, + 14336, + 14784, + 14848, + 15360, + 18944, + 22016, + 22528, + 24576, + 27392, + 27648, + 29568, + 29696, + 32000, + 32256, + 32512, + 32768, + 33024, + 36864, + 43264, + 49152, + 49408, + 60544, + 60672, + 64000, + 64256, + 102400, + 102656, + 128000, + 128256, +] +#The size of TP +divisibility = [1, 2, 8, 16, 64] + +all_hidden_size = [] +for div in divisibility: + for hidden_size in HIDDEN_SIZES: + all_hidden_size.append(hidden_size // div) + +HIDDEN_SIZES = list(set(all_hidden_size)) + +BATCHES = [4] +NUM_LORA = [4] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [32] +SCALES = [0.5] +SEED = [0] + +DEVICES = ["cuda:0"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) 
+ + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + # triton op + torch.ops.vllm.sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + #torch op + sgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + else: + torch.ops.vllm.sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + add_inputs=True, + ) + sgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + add_inputs=True, + ) + + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + torch.ops.vllm.bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + + bgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + scaling, + ) + + else: + torch.ops.vllm.bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + bgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + add_inputs=True, + ) + + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + 
dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + torch.ops.vllm.sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + sgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + + torch.ops.vllm.bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py new file mode 100644 index 000000000000..2f553ec47d63 --- /dev/null +++ b/tests/lora/test_punica_ops_variation.py @@ -0,0 +1,344 @@ +""" +This script is mainly used to test whether trtion kernels can run normally +under different conditions, including various batches, numbers of LoRA , and +maximum ranks. 
+""" +import pytest +import torch + +# Enable custom op register +import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink, sgmv_expand, + sgmv_expand_slice, sgmv_shrink) +from vllm.platforms import current_platform + +from .utils import generate_data, generate_data_for_expand_nslices + +HIDDEN_SIZES = [4097] + +BATCHES = [1, 4, 16, 32] +NUM_LORA = [1, 8, 32, 128] +DTYPES = [torch.float16, torch.bfloat16] +MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] +SCALES = [0.5] +SEED = [0] + +DEVICES = ["cuda:0"] + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_sgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + seq_length = 128 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + if op_type == "shrink": + # triton op + torch.ops.vllm.sgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + # torch op + sgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + else: + torch.ops.vllm.sgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + add_inputs=True, + ) + sgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + add_inputs=True, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_bgmv( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + + 
torch.set_default_device(device) + current_platform.seed_everything(seed) + + seq_length = 1 + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + if op_type == "shrink": + torch.ops.vllm.bgmv_shrink( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + scaling, + ) + bgmv_shrink( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + scaling, + ) + + # bgmv_expand = torch.ops.vllm.bgmv_expand + # bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + # bgmv_shrink = torch.ops.vllm.bgmv_shrink + # sgmv_expand = torch.ops.vllm.sgmv_expand + # sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice + # sgmv_shrink = torch.ops.vllm.sgmv_shrink + else: + + torch.ops.vllm.bgmv_expand( + inputs_tensor, + lora_weights, + our_out_tensor, + indices, + add_inputs=True, + ) + bgmv_expand( + inputs_tensor, + lora_weights, + ref_out_tensor, + indices, + add_inputs=True, + ) + # ref_torch_groupgemm( + # ref_out_tensor, + # inputs_tensor, + # lora_weights, + # lora_indices_tensor, + # seq_len_tensor, + # batches, + # scaling if op_type == "shrink" else 1.0, + # op_type, + # ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", DEVICES) +def test_punica_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + op_type: str, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + seq_length = 128 if op_type == "sgmv" else 1 + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() + if isinstance(max_seq_length, tuple): + max_seq_length = max_seq_length[0].item() + else: + max_seq_length = max_seq_length.item() + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + if op_type == "sgmv": + torch.ops.vllm.sgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + + sgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + else: + torch.ops.vllm.bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + 
) + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) From cfa082ff5a3d8f527b5a9cb6fd434eb257bc05fe Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Wed, 8 Jan 2025 09:35:27 +0000 Subject: [PATCH 13/23] Removed assert Signed-off-by: Akshat Tripathi --- vllm/executor/cpu_executor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 5495bc50ede8..b9a6bee5720f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -22,9 +22,6 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" - # Reminder: Please update docs/source/usage/compatibility_matrix.md - # If the feature combo become valid - assert self.lora_config is None, "cpu backend doesn't support LoRA" # # Environment variables for CPU executor From 21c3799ab78d5101f53e267e669a1f053ce0cd57 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 9 Jan 2025 06:37:37 +0000 Subject: [PATCH 14/23] Resolve test conflicts Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_sizes.py | 216 +++++--------------- tests/lora/test_punica_ops_variation.py | 252 ++++++------------------ tests/lora/utils.py | 27 --- vllm/lora/ops/triton_ops/__init__.py | 2 - 4 files changed, 109 insertions(+), 388 deletions(-) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py index 00f073f6471b..47f652392650 100644 --- a/tests/lora/test_punica_ops_sizes.py +++ b/tests/lora/test_punica_ops_sizes.py @@ -9,28 +9,15 @@ import pytest import torch -<<<<<<< HEAD:tests/lora/test_punica_ops_sizes.py -# Enable custom op register import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) -from vllm.platforms import current_platform - -from .utils import generate_data, generate_data_for_expand_nslices -======= -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_shrink import sgmv_shrink -from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform - from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, - generate_data_for_nslices, ref_torch_groupgemm) ->>>>>>> main:tests/lora/test_punica_sizes.py + generate_data_for_nslices) HIDDEN_SIZES = [ 128, @@ -124,8 +111,7 @@ MAX_RANKS = [32] SCALES = [0.5] SEED = [0] - -DEVICES = ["cuda:0"] +DEVICES = [f"cuda:{0}"] _dict_lock = Lock() @@ -137,7 +123,7 @@ @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -183,66 +169,10 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": -<<<<<<< HEAD:tests/lora/test_punica_ops_sizes.py - # triton op - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - 
max_seq_length, - token_nums, - scaling, - ) - #torch op - sgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - else: - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - sgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) -======= # Preventing cache error pointer. with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( + torch.ops.vllm.sgmv_shrink( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -255,20 +185,23 @@ def test_punica_sgmv( scaling, ) for index in range(nslices): - ref_torch_groupgemm( - ref_out_tensor[index], + sgmv_shrink( inputs_tensor, lora_weights_lst[index], - lora_indices_tensor, + ref_out_tensor[index], + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, + max_seq_length, + token_nums, scaling, - op_type, ) + else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( + torch.ops.vllm.sgmv_expand( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -281,23 +214,40 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - ref_torch_groupgemm( - ref_out_tensor[:, slice_offset:slice_offset + hidden_size], - inputs_tensor[index], - lora_weights, - lora_indices_tensor, + if nslices==1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + ref_out_tensor, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, - 1.0, - op_type, + max_seq_length, + token_nums, + add_inputs=True, ) - slice_offset += hidden_size + else: + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + slice_offset += hidden_size ->>>>>>> main:tests/lora/test_punica_sizes.py assert_close(our_out_tensor, ref_out_tensor) @@ -389,13 +339,8 @@ def test_punica_bgmv( @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEED) -<<<<<<< HEAD:tests/lora/test_punica_ops_sizes.py @pytest.mark.parametrize("device", DEVICES) -def test_punica_expand_nslices( -======= -@pytest.mark.parametrize("device", CUDA_DEVICES) def test_punica_bgmv_expand_nslices( ->>>>>>> main:tests/lora/test_punica_sizes.py batches: int, num_loras: int, rank: int, @@ -431,58 +376,7 @@ def test_punica_bgmv_expand_nslices( slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] -<<<<<<< HEAD:tests/lora/test_punica_ops_sizes.py - if op_type == "sgmv": - torch.ops.vllm.sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - sgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - 
batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) -======= - bgmv_expand_slice( + torch.ops.vllm.bgmv_expand_slice( inputs_tensor, lora_weights, our_outputs, @@ -491,17 +385,15 @@ def test_punica_bgmv_expand_nslices( slice_size=hidden_size, add_inputs=True, ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + hidden_size], - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", - ) ->>>>>>> main:tests/lora/test_punica_sizes.py + bgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) slice_offset += hidden_size assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py index 3e1a32ec57e2..36c103e50ab2 100644 --- a/tests/lora/test_punica_ops_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -9,27 +9,16 @@ import torch # Enable custom op register -<<<<<<< HEAD:tests/lora/test_punica_ops_variation.py import vllm.lora.ops.triton_ops # noqa: F401 +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) -from vllm.platforms import current_platform - -from .utils import generate_data, generate_data_for_expand_nslices -======= -import vllm.lora.ops.bgmv_expand -import vllm.lora.ops.bgmv_expand_slice -import vllm.lora.ops.bgmv_shrink -import vllm.lora.ops.sgmv_expand -import vllm.lora.ops.sgmv_shrink # noqa: F401 -from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, - generate_data_for_nslices, ref_torch_groupgemm) ->>>>>>> main:tests/lora/test_punica_variation.py + generate_data_for_nslices) HIDDEN_SIZES = [4097] @@ -39,30 +28,9 @@ MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256] SCALES = [0.5] SEED = [0] - -DEVICES = ["cuda:0"] - -<<<<<<< HEAD:tests/lora/test_punica_ops_variation.py - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -======= -# Unlike test_punica_sizes.py, we directly utilize custom op for -# testing, which verifies the correct registration of these ops. 
-bgmv_expand = torch.ops.vllm.bgmv_expand -bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice -bgmv_shrink = torch.ops.vllm.bgmv_shrink -sgmv_expand = torch.ops.vllm.sgmv_expand -sgmv_shrink = torch.ops.vllm.sgmv_shrink +DEVICES = [f"cuda:{0}"] _dict_lock = Lock() ->>>>>>> main:tests/lora/test_punica_variation.py - @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @@ -71,7 +39,7 @@ def assert_close(a, b): @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -117,65 +85,10 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": -<<<<<<< HEAD:tests/lora/test_punica_ops_variation.py - # triton op - torch.ops.vllm.sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - # torch op - sgmv_shrink( - inputs_tensor, - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) - else: - torch.ops.vllm.sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - sgmv_expand( - inputs_tensor, - lora_weights, - ref_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) -======= # Preventing cache error pointer. 
with _dict_lock: _LORA_A_PTR_DICT.clear() - sgmv_shrink( + torch.ops.vllm.sgmv_shrink( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -188,20 +101,23 @@ def test_punica_sgmv( scaling, ) for index in range(nslices): - ref_torch_groupgemm( - ref_out_tensor[index], + sgmv_shrink( inputs_tensor, lora_weights_lst[index], - lora_indices_tensor, + ref_out_tensor[index], + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, + max_seq_length, + token_nums, scaling, - op_type, ) + else: with _dict_lock: _LORA_B_PTR_DICT.clear() - sgmv_expand( + torch.ops.vllm.sgmv_expand( inputs_tensor, lora_weights_lst, our_out_tensor, @@ -214,26 +130,42 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - slice_offset = 0 - for index in range(nslices): - lora_weights = lora_weights_lst[index] - ref_torch_groupgemm( - ref_out_tensor[:, slice_offset:slice_offset + hidden_size], - inputs_tensor[index], - lora_weights, - lora_indices_tensor, + if nslices==1: + # Verify the torch's sgmv_expand op + sgmv_expand( + inputs_tensor[0], + lora_weights_lst[0], + ref_out_tensor, + b_seq_start_loc, seq_len_tensor, + lora_indices_tensor, batches, - 1.0, - op_type, + max_seq_length, + token_nums, + add_inputs=True, ) - slice_offset += hidden_size + else: + for index in range(nslices): + lora_weights = lora_weights_lst[index] + sgmv_expand_slice( + inputs_tensor[index], + lora_weights, + ref_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + slice_offset, + hidden_size, + add_inputs=True, + ) + slice_offset += hidden_size ->>>>>>> main:tests/lora/test_punica_variation.py assert_close(our_out_tensor, ref_out_tensor) - @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -254,7 +186,6 @@ def test_punica_bgmv( seed: int, device: str, ): - torch.set_default_device(device) current_platform.seed_everything(seed) @@ -286,6 +217,7 @@ def test_punica_bgmv( indices, scaling, ) + bgmv_shrink( inputs_tensor, lora_weights, @@ -294,14 +226,7 @@ def test_punica_bgmv( scaling, ) - # bgmv_expand = torch.ops.vllm.bgmv_expand - # bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice - # bgmv_shrink = torch.ops.vllm.bgmv_shrink - # sgmv_expand = torch.ops.vllm.sgmv_expand - # sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice - # sgmv_shrink = torch.ops.vllm.sgmv_shrink else: - torch.ops.vllm.bgmv_expand( inputs_tensor, lora_weights, @@ -316,16 +241,7 @@ def test_punica_bgmv( indices, add_inputs=True, ) - # ref_torch_groupgemm( - # ref_out_tensor, - # inputs_tensor, - # lora_weights, - # lora_indices_tensor, - # seq_len_tensor, - # batches, - # scaling if op_type == "shrink" else 1.0, - # op_type, - # ) + if op_type == "shrink": ref_out_tensor = ref_out_tensor.to(torch.float32) assert_close(our_out_tensor, ref_out_tensor) @@ -338,13 +254,8 @@ def test_punica_bgmv( @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEED) -<<<<<<< HEAD:tests/lora/test_punica_ops_variation.py @pytest.mark.parametrize("device", DEVICES) -def test_punica_expand_nslices( -======= -@pytest.mark.parametrize("device", CUDA_DEVICES) def test_punica_bgmv_expand_nslices( ->>>>>>> main:tests/lora/test_punica_variation.py batches: int, num_loras: int, rank: int, @@ -380,58 +291,7 @@ def test_punica_bgmv_expand_nslices( slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] -<<<<<<< 
HEAD:tests/lora/test_punica_ops_variation.py - if op_type == "sgmv": - torch.ops.vllm.sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - - sgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - torch.ops.vllm.bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) - bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) -======= - bgmv_expand_slice( + torch.ops.vllm.bgmv_expand_slice( inputs_tensor, lora_weights, our_outputs, @@ -440,17 +300,15 @@ def test_punica_bgmv_expand_nslices( slice_size=hidden_size, add_inputs=True, ) - ref_torch_groupgemm( - ref_outputs[:, slice_offset:slice_offset + hidden_size], - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - 1.0, - op_type="expand", - ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) ->>>>>>> main:tests/lora/test_punica_variation.py slice_offset += hidden_size assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b66d18074a7b..ce47546f2154 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -104,33 +104,6 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) -def ref_torch_groupgemm( - out_tensor, - inputs, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling, - op_type, -) -> torch.Tensor: - out_list = [] - current_offset = 0 - for lora_index, b_length in zip(range(batches), seq_len_tensor): - input_weight = inputs[current_offset:b_length + current_offset, :] - current_offset += b_length - lora_weight = lora_weights[lora_indices_tensor[lora_index]] - result = torch.nn.functional.linear(input_weight, lora_weight) - result *= scaling - out_list.append(result) - cat_result = torch.cat(out_list, dim=0) - if op_type == "expand": - out_tensor += cat_result - else: - out_tensor.copy_(cat_result) - return - - def generate_data( batches, hidden_size, diff --git a/vllm/lora/ops/triton_ops/__init__.py b/vllm/lora/ops/triton_ops/__init__.py index 70d08c682713..9805b6dd5038 100644 --- a/vllm/lora/ops/triton_ops/__init__.py +++ b/vllm/lora/ops/triton_ops/__init__.py @@ -2,7 +2,6 @@ from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.triton_ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink # noqa: F401 __all__ = [ @@ -10,6 +9,5 @@ "bgmv_expand_slice", "bgmv_shrink", "sgmv_expand", - "sgmv_expand_slice", "sgmv_shrink", ] From 61c42cfb1f0f03e11766fe8f93b78bf307437ad1 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 9 Jan 2025 06:41:40 +0000 Subject: [PATCH 15/23] Sync main Signed-off-by: Jee Jee Li --- vllm/lora/punica_wrapper/punica_selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/lora/punica_wrapper/punica_selector.py 
b/vllm/lora/punica_wrapper/punica_selector.py index 85494e4cb160..9f1606e672de 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -15,7 +15,7 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: elif current_platform.is_cpu(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU - print_info_once("Using PunicaWrapperCPU.") + logger.info_once("Using PunicaWrapperCPU.") return PunicaWrapperCPU(*args, **kwargs) elif current_platform.is_hpu(): # Lazy import to avoid ImportError From 0d19f0317bb4be58643f2a14adda05ddf148be2a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 9 Jan 2025 06:44:45 +0000 Subject: [PATCH 16/23] Make isort happy Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_sizes.py | 5 +++-- tests/lora/test_punica_ops_variation.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py index 47f652392650..9ec72e48959c 100644 --- a/tests/lora/test_punica_ops_sizes.py +++ b/tests/lora/test_punica_ops_sizes.py @@ -10,11 +10,12 @@ import torch import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform + from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, generate_data_for_nslices) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py index 36c103e50ab2..6d3d79c2d7b8 100644 --- a/tests/lora/test_punica_ops_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -10,11 +10,11 @@ # Enable custom op register import vllm.lora.ops.triton_ops # noqa: F401 -from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT -from vllm.platforms import current_platform from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice, bgmv_shrink, sgmv_expand, sgmv_expand_slice, sgmv_shrink) +from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT +from vllm.platforms import current_platform from .utils import (assert_close, generate_data, generate_data_for_expand_nslices, From 5af9cbb9e006781f5585c402ff29d87d9aa65345 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 9 Jan 2025 06:52:23 +0000 Subject: [PATCH 17/23] Make yapf happy Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_sizes.py | 22 +++++++++++----------- tests/lora/test_punica_ops_variation.py | 24 +++++++++++++----------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py index 9ec72e48959c..433ca7577d08 100644 --- a/tests/lora/test_punica_ops_sizes.py +++ b/tests/lora/test_punica_ops_sizes.py @@ -124,7 +124,7 @@ @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -198,7 +198,7 @@ def test_punica_sgmv( token_nums, scaling, ) - + else: with _dict_lock: 
_LORA_B_PTR_DICT.clear() @@ -215,7 +215,7 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - if nslices==1: + if nslices == 1: # Verify the torch's sgmv_expand op sgmv_expand( inputs_tensor[0], @@ -387,14 +387,14 @@ def test_punica_bgmv_expand_nslices( add_inputs=True, ) bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) slice_offset += hidden_size assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py index 6d3d79c2d7b8..2583da3fb6c0 100644 --- a/tests/lora/test_punica_ops_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -32,6 +32,7 @@ _dict_lock = Lock() + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -39,7 +40,7 @@ @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -113,7 +114,7 @@ def test_punica_sgmv( token_nums, scaling, ) - + else: with _dict_lock: _LORA_B_PTR_DICT.clear() @@ -131,7 +132,7 @@ def test_punica_sgmv( add_inputs=True, ) slice_offset = 0 - if nslices==1: + if nslices == 1: # Verify the torch's sgmv_expand op sgmv_expand( inputs_tensor[0], @@ -166,6 +167,7 @@ def test_punica_sgmv( assert_close(our_out_tensor, ref_out_tensor) + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -301,14 +303,14 @@ def test_punica_bgmv_expand_nslices( add_inputs=True, ) bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) slice_offset += hidden_size assert_close(our_outputs, ref_outputs) From 202aca364ad34c9c32bba72d674341e6b433102c Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 10 Jan 2025 15:09:03 +0000 Subject: [PATCH 18/23] run-cpu-test.sh now runs multi-lora tests Signed-off-by: Akshat Tripathi --- .buildkite/run-cpu-test.sh | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 1a4dae8f65e9..85a27cd0e30d 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -75,6 +75,13 @@ function cpu_tests() { --num-prompts 20 \ --endpoint /v1/completions \ --tokenizer facebook/opt-125m" + + # Run multi-lora tests + docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " + set -e + pytest -s -v \ + tests/lora/test_qwen2vl.py \ + tests/lora/test_lora_bias_e2e.py" } # All of CPU tests are expected to be finished less than 25 mins. 
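For reference, the CI step added above exercises the new CPU multi-LoRA path only through pytest. A minimal user-side sketch of the same path follows; it uses the standard vLLM LoRA API, and the base model name and adapter path are placeholders rather than anything taken from this series (running it on the CPU backend assumes a CPU build of vLLM).

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model and adapter; any LoRA adapter trained against the
# chosen base model would do. enable_lora/max_loras are the usual vLLM knobs.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)

outputs = llm.generate(
    ["Write a SQL query listing all users created this week."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest(name, int_id, local_path): int_id must be a positive integer
    # that is unique per adapter.
    lora_request=LoRARequest("sql-lora", 1, "/path/to/sql_lora_adapter"),
)
print(outputs[0].outputs[0].text)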
From 76444f4229ca87ce9a3d26820eed1bbd34964f8a Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 10 Jan 2025 16:47:52 +0000 Subject: [PATCH 19/23] Updated compatibility docs Signed-off-by: Akshat Tripathi --- docs/source/features/compatibility_matrix.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index 8d8f7dca2e5b..ea1d545ff3d7 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -359,7 +359,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar - ✅ - ✅ - ✅ - - [✗](gh-pr:4830) + - ✅ - ✅ * - prmpt adptr - ✅ From 64c700f79185237c7c74945ab0bef451454e7265 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 10 Jan 2025 17:06:18 +0000 Subject: [PATCH 20/23] Update .buildkite/run-cpu-test.sh Co-authored-by: Isotr0py <2037008807@qq.com> Signed-off-by: Akshat Tripathi --- .buildkite/run-cpu-test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 85a27cd0e30d..ca5e8083831e 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -81,7 +81,6 @@ function cpu_tests() { set -e pytest -s -v \ tests/lora/test_qwen2vl.py \ - tests/lora/test_lora_bias_e2e.py" } # All of CPU tests are expected to be finished less than 25 mins. From 06399bc3b27b9c9153862c15bc199c13197f6af8 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 10 Jan 2025 17:08:12 +0000 Subject: [PATCH 21/23] Lint Signed-off-by: Akshat Tripathi --- .buildkite/run-cpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index ca5e8083831e..b9f005ade900 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -80,7 +80,7 @@ function cpu_tests() { docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pytest -s -v \ - tests/lora/test_qwen2vl.py \ + tests/lora/test_qwen2vl.py } # All of CPU tests are expected to be finished less than 25 mins. From 69eb3dc65f7f5e2eb186ad1c91ab847f0babc453 Mon Sep 17 00:00:00 2001 From: Akshat Tripathi Date: Fri, 10 Jan 2025 17:11:06 +0000 Subject: [PATCH 22/23] Lint2 Signed-off-by: Akshat Tripathi --- .buildkite/run-cpu-test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index b9f005ade900..f190aebf1961 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -80,7 +80,7 @@ function cpu_tests() { docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " set -e pytest -s -v \ - tests/lora/test_qwen2vl.py + tests/lora/test_qwen2vl.py" } # All of CPU tests are expected to be finished less than 25 mins. From d8cb9afda54b259075808faa1467664b40da0700 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 12 Jan 2025 03:34:03 +0000 Subject: [PATCH 23/23] Reduce test memory Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_variation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py index 2583da3fb6c0..2bb84c1cf11e 100644 --- a/tests/lora/test_punica_ops_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -20,7 +20,7 @@ generate_data_for_expand_nslices, generate_data_for_nslices) -HIDDEN_SIZES = [4097] +HIDDEN_SIZES = [2049] BATCHES = [1, 4, 16, 32] NUM_LORA = [1, 8, 32, 128]
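The final commit above halves HIDDEN_SIZES (4097 -> 2049) purely to reduce the variation test's memory footprint: every large tensor in the punica tests scales linearly with the hidden size. The sketch below is a rough back-of-the-envelope estimate only; the tensor shapes and token count are illustrative assumptions, not the exact shapes produced by generate_data_for_nslices.

import torch

def approx_expand_test_bytes(hidden_size: int,
                             num_loras: int = 128,          # largest NUM_LORA value
                             max_rank: int = 256,           # largest MAX_RANKS value
                             total_tokens: int = 32 * 128,  # assumed batches x seq len
                             nslices: int = 3,
                             dtype: torch.dtype = torch.float16) -> int:
    # Assumed shapes: LoRA B weights per slice ~ (num_loras, hidden_size, rank);
    # the op output and the reference output ~ (total_tokens, hidden_size * nslices).
    elem = torch.finfo(dtype).bits // 8
    lora_b = num_loras * hidden_size * max_rank * nslices
    outputs = 2 * total_tokens * hidden_size * nslices
    return (lora_b + outputs) * elem

for h in (4097, 2049):
    print(f"hidden_size={h}: ~{approx_expand_test_bytes(h) / 2**20:.0f} MiB")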