diff --git a/tests/v1/test_utils.py b/vllm/v1/test_utils.py
similarity index 97%
rename from tests/v1/test_utils.py
rename to vllm/v1/test_utils.py
index b68f08385866..a488f09f6526 100644
--- a/tests/v1/test_utils.py
+++ b/vllm/v1/test_utils.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from vllm.v1.utils import bind_kv_cache
+from vllm.v1.worker.tpu_model_runner import bind_kv_cache
 
 
 def test_bind_kv_cache():
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 8e1fb18cca05..2820e2f37880 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -3,20 +3,14 @@
 import multiprocessing
 import os
 import weakref
-from collections import defaultdict
 from collections.abc import Sequence
-from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
-                    Union, overload)
+from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload
 
 import torch
 
 from vllm.logger import init_logger
-from vllm.model_executor.models.utils import extract_layer_index
 from vllm.utils import get_mp_context, kill_process_tree
 
-if TYPE_CHECKING:
-    from vllm.attention.layer import Attention
-
 logger = init_logger(__name__)
 
 T = TypeVar("T")
@@ -145,51 +139,6 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
             os.remove(socket_file)
 
 
-def bind_kv_cache(
-    kv_caches: dict[str, torch.Tensor],
-    forward_context: dict[str, "Attention"],
-    runner_kv_caches: list[torch.Tensor],
-) -> None:
-    """
-    Bind the allocated KV cache to both ModelRunner and forward context so
-    that the KV cache can be used in the forward pass.
-
-    This function:
-      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
-         kv_caches.
-      2) Associates each attention layer in the `forward_context` with its
-         corresponding KV cache in kv_caches.
-
-    Args:
-        kv_caches: The allocated kv_caches with layer names as keys.
-        forward_context: The global forward context containing all Attention
-            layers with layer names as keys.
-        runner_kv_caches: The kv_cache declared by ModelRunner.
-    """
-    # Bind kv_caches to ModelRunner
-    assert len(runner_kv_caches) == 0
-
-    # Convert kv_caches dict to a list of tensors in the order of layer_index.
-    index2name = defaultdict(list)
-    for layer_name in kv_caches:
-        index2name[extract_layer_index(layer_name)].append(layer_name)
-
-    for layer_index in sorted(index2name.keys()):
-        layer_names = index2name[layer_index]
-        if len(layer_names) > 1:
-            # One typical case is encoder-decoder model, e.g., bart.
-            # The cross attention and self attention in the same decoder layer
-            # has different layer_name but the same layer_index.
-            raise NotImplementedError
-        layer_name = layer_names[0]
-        runner_kv_caches.append(kv_caches[layer_name])
-
-    # Bind kv_caches to forward context
-    for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
-
-
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
                length: int) -> torch.Tensor:
     """
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4a1fb0514c3f..9088c7eab73f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -34,7 +34,6 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
@@ -135,7 +134,6 @@ def __init__(
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
-        self.kv_caches: list[torch.Tensor] = []
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
@@ -1382,10 +1380,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             else:
                 raise NotImplementedError
 
-        bind_kv_cache(
-            kv_caches,
-            self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+        # Associates each attention layer in the `forward_context` with the
+        # initialized KV cache.
+        forward_context = self.vllm_config.compilation_config \
+            .static_forward_context
+        for layer_name, kv_cache in kv_caches.items():
+            # NOTE: Use list because of v0 PP virtual engine.
+            forward_context[layer_name].kv_cache = [kv_cache]
 
     def get_kv_cache_spec(self) -> KVCacheSpec:
         """
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 104e5a3dcfc5..3f2917b3227f 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import time
+from collections import defaultdict
 from typing import TYPE_CHECKING, Optional, cast
 from unittest.mock import patch
 
@@ -17,6 +18,7 @@
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
@@ -26,7 +28,6 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 if TYPE_CHECKING:
@@ -40,6 +41,51 @@
 INVALID_TOKEN_ID = -1
 
 
+def bind_kv_cache(
+    kv_caches: dict[str, torch.Tensor],
+    forward_context: dict[str, "Attention"],
+    runner_kv_caches: list[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
+
+
 class TPUModelRunner:
 
     def __init__(
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index cbd2fe6edd81..1ce9905fe489 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -20,8 +20,7 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.tpu_model_runner import TPUModelRunner, bind_kv_cache
 
 logger = init_logger(__name__)