2 changes: 1 addition & 1 deletion tests/v1/test_utils.py → vllm/v1/test_utils.py
Collaborator:

This file shouldn't be here (in vllm/)?

@@ -2,7 +2,7 @@

import torch

from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import bind_kv_cache


def test_bind_kv_cache():
53 changes: 1 addition & 52 deletions vllm/v1/utils.py
@@ -3,20 +3,14 @@
import multiprocessing
import os
import weakref
from collections import defaultdict
from collections.abc import Sequence
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload

import torch

from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.utils import get_mp_context, kill_process_tree

if TYPE_CHECKING:
    from vllm.attention.layer import Attention

logger = init_logger(__name__)

T = TypeVar("T")
@@ -145,51 +139,6 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str):
os.remove(socket_file)


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
13 changes: 7 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
@@ -34,7 +34,6 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

@@ -135,7 +134,6 @@ def __init__(

        # Lazy initialization
        # self.model: nn.Module # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

@@ -1382,10 +1380,13 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
            else:
                raise NotImplementedError

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)
        # Associates each attention layer in the `forward_context` with the
        # initialized KV cache.
        forward_context = self.vllm_config.compilation_config \
            .static_forward_context
        for layer_name, kv_cache in kv_caches.items():
            # NOTE: Use list because of v0 PP virtual engine.
            forward_context[layer_name].kv_cache = [kv_cache]
Comment on lines +1383 to +1389
Collaborator:

This is still a kind of "binding"? To unify the interface for runners, would the following be better?

bind_kv_cache(
    kv_caches,
    self.vllm_config.compilation_config.static_forward_context,
)

And in TPU model runner:

bind_kv_cache(
    kv_caches,
    self.vllm_config.compilation_config.static_forward_context,
    self.kv_caches,
)

So we could have

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: Optional[list[torch.Tensor]] = None,
) -> None:
    # Bind runner kv caches. Also comment
    # this is only used by TPU for now.
    if runner_kv_caches is not None:
        ...

    # Bind forward_context.
    ...

If TPU is the only model runner (now and future) that needs this, alternatively we could have a separate utility:

def bind_runner_kv_caches(
    kv_caches: dict[str, torch.Tensor],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    ...

And in TPU model runner:

bind_kv_cache(
    kv_caches,
    self.vllm_config.compilation_config.static_forward_context,
)
bind_runner_kv_caches(
    kv_caches,
    self.kv_caches,
)
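
For concreteness, this is one way the suggested unified helper could look if it reused the bind_kv_cache body added to tpu_model_runner.py in this PR; it is a sketch of the reviewer's idea, not code from the PR.

from collections import defaultdict
from typing import Optional

import torch

from vllm.model_executor.models.utils import extract_layer_index


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: Optional[list[torch.Tensor]] = None,
) -> None:
    # Bind runner kv caches (currently only needed by the TPU runner).
    if runner_kv_caches is not None:
        assert len(runner_kv_caches) == 0
        # Order the caches by layer_index, as in the implementation above.
        index2name = defaultdict(list)
        for layer_name in kv_caches:
            index2name[extract_layer_index(layer_name)].append(layer_name)
        for layer_index in sorted(index2name.keys()):
            layer_names = index2name[layer_index]
            if len(layer_names) > 1:
                # e.g. encoder-decoder models, where cross attention and self
                # attention share a layer_index under different layer names.
                raise NotImplementedError
            runner_kv_caches.append(kv_caches[layer_names[0]])

    # Bind kv_caches to the forward context.
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]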

Collaborator Author:

In tpu_model_runner, self.kv_caches is only used to detect whether it is a profile run and to read things like num_kv_heads, num_blocks, and block_size. I think those uses can be removed in the future if we refactor ModelWrapperV1, but I'm not very familiar with the TPU backend.

The binding to the forward context is quite simple, so I prefer to implement it directly in initialize_kv_cache rather than call a separate function, provided we can remove the binding to self.kv_caches.
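
To illustrate the kind of usage described above, a purely hypothetical sketch (the helper names and shape handling are assumptions, not actual tpu_model_runner code):

import torch


def is_profile_run(runner_kv_caches: list[torch.Tensor]) -> bool:
    # Before the KV cache is bound, the runner's list is still empty.
    return len(runner_kv_caches) == 0


def first_cache_shape(runner_kv_caches: list[torch.Tensor]) -> torch.Size:
    # Metadata such as num_blocks, block_size, or num_kv_heads would be read
    # from here; the exact dimension order is backend-specific.
    return runner_kv_caches[0].shape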

Collaborator:

OK, if this is confirmed, please add a TODO/FIXME to bind_kv_cache in the TPU runner, and we can keep it as it is.
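
For example, the requested marker could read roughly as follows (wording is illustrative only):

# TODO/FIXME: runner_kv_caches (self.kv_caches) is only kept for the TPU
# backend's ModelWrapperV1; drop it once #14309 removes that dependency.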

Collaborator:

Actually @heheda12345, could you check whether v1's tpu_model_runner really uses self.kv_caches? From a quick check, the kv_caches: list[tuple[torch.Tensor, torch.Tensor]] argument appears never to be used, even though it is listed as a parameter.

Collaborator Author:

It is used in class ModelWrapperV1(nn.Module), and we are trying to remove it in #14309 before this PR.


    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
48 changes: 47 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch

@@ -17,6 +18,7 @@
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
@@ -26,7 +28,6 @@
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

if TYPE_CHECKING:
@@ -40,6 +41,51 @@
INVALID_TOKEN_ID = -1


def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]


class TPUModelRunner:

    def __init__(
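
For reference, a minimal sketch of how the relocated bind_kv_cache can be exercised, using hypothetical layer names and simple stand-ins for Attention modules (this is not the actual test in vllm/v1/test_utils.py):

from types import SimpleNamespace

import torch

from vllm.v1.worker.tpu_model_runner import bind_kv_cache

# Hypothetical layer names; extract_layer_index() derives the index from the
# single integer component of the dotted name.
kv_caches = {
    "model.layers.0.attn": torch.zeros(2, 2),
    "model.layers.1.attn": torch.zeros(2, 2),
}
# Stand-ins for Attention modules; only the .kv_cache attribute is assigned.
forward_context = {name: SimpleNamespace(kv_cache=None) for name in kv_caches}
runner_kv_caches: list[torch.Tensor] = []

bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

# Runner list is filled in layer_index order; forward context holds the same
# tensors wrapped in single-element lists.
assert runner_kv_caches[0] is kv_caches["model.layers.0.attn"]
assert (forward_context["model.layers.1.attn"].kv_cache[0]
        is kv_caches["model.layers.1.attn"])
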
3 changes: 1 addition & 2 deletions vllm/v1/worker/tpu_worker.py
@@ -20,8 +20,7 @@
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.tpu_model_runner import TPUModelRunner, bind_kv_cache

logger = init_logger(__name__)

Expand Down