
Commit 2cb84f2

support pp virtual engine
Signed-off-by: Chen Zhang <[email protected]>
1 parent: ffe8cdd

13 files changed, +56 −39 lines

vllm/attention/layer.py

Lines changed: 9 additions & 3 deletions
@@ -123,7 +123,11 @@ def __init__(
         self.attn_type = attn_type
         # use a placeholder kv cache tensor during init, which will be replaced
         # by bind_kv_cache
-        self.kv_cache = torch.tensor([])
+        # this variable will not be accessed if use_direct_call is True
+        self.kv_cache = [
+            torch.tensor([]) for _ in range(get_current_vllm_config(
+            ).parallel_config.pipeline_parallel_size)
+        ]
 
     def forward(
         self,
@@ -238,7 +242,8 @@ def unified_attention(
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
-    return self.impl.forward(query, key, value, self.kv_cache, attn_metadata,
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
+    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
                              self._k_scale, self._v_scale)
 
 
@@ -270,10 +275,11 @@ def unified_attention_with_output(
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
+    kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(query,
                       key,
                       value,
-                      self.kv_cache,
+                      kv_cache,
                       attn_metadata,
                       self._k_scale,
                       self._v_scale,
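
The key change here is that each Attention module now holds one kv cache slot per pipeline-parallel virtual engine and selects the right one through forward_context.virtual_engine at call time. Below is a minimal toy sketch of that indexing scheme; ToyAttention and ToyForwardContext are illustrative stand-ins, not the vLLM classes.

# Toy illustration of the per-virtual-engine kv cache indexing above.
# ToyAttention / ToyForwardContext are stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyForwardContext:
    virtual_engine: int  # which pipeline-parallel virtual engine is running


class ToyAttention:

    def __init__(self, pipeline_parallel_size: int):
        # One placeholder per virtual engine; bind_kv_cache replaces them later.
        self.kv_cache: List[torch.Tensor] = [
            torch.tensor([]) for _ in range(pipeline_parallel_size)
        ]

    def forward(self, ctx: ToyForwardContext) -> torch.Tensor:
        # Pick the cache that belongs to the currently running virtual engine.
        return self.kv_cache[ctx.virtual_engine]


attn = ToyAttention(pipeline_parallel_size=2)
attn.kv_cache[1] = torch.zeros(4)  # pretend bind_kv_cache filled virtual engine 1
print(attn.forward(ToyForwardContext(virtual_engine=1)).shape)  # torch.Size([4])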

vllm/forward_context.py

Lines changed: 6 additions & 1 deletion
@@ -28,6 +28,8 @@ class ForwardContext:
     attn_layers: Dict[str, Any]
     # TODO: extend to support per-layer dynamic forward context
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    # TODO: remove after making all virtual_engines share the same kv cache
+    virtual_engine: int  # set dynamically for each forward pass
 
 
 _forward_context: Optional[ForwardContext] = None
@@ -42,7 +44,9 @@ def get_forward_context() -> ForwardContext:
 
 
 @contextmanager
-def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
+def set_forward_context(attn_metadata: Any,
+                        vllm_config: VllmConfig,
+                        virtual_engine: int = 0):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -55,6 +59,7 @@ def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig):
     prev_context = _forward_context
     _forward_context = ForwardContext(
         attn_layers=vllm_config.compilation_config.static_forward_context,
+        virtual_engine=virtual_engine,
         attn_metadata=attn_metadata)
     try:
         yield
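
set_forward_context now threads a virtual_engine index into the ForwardContext, defaulting to 0 so existing single-engine call sites stay valid. A minimal sketch of that context-manager pattern with stand-in names (not the vLLM implementation):

# Sketch of the pattern above: stash the virtual engine in a module-level
# context for the duration of a forward pass, restoring the previous one on exit.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchForwardContext:
    attn_metadata: Any
    virtual_engine: int


_ctx: Optional[SketchForwardContext] = None


@contextmanager
def sketch_set_forward_context(attn_metadata: Any, virtual_engine: int = 0):
    global _ctx
    prev, _ctx = _ctx, SketchForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _ctx = prev


with sketch_set_forward_context(attn_metadata=None, virtual_engine=1):
    assert _ctx is not None and _ctx.virtual_engine == 1
assert _ctx is None  # previous context restored on exit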

vllm/utils.py

Lines changed: 11 additions & 4 deletions
@@ -1949,16 +1949,20 @@ def get_mp_context():
     return multiprocessing.get_context(mp_method)
 
 
-def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None:
+def bind_kv_cache(
+    ctx: Dict[str, Any],
+    kv_cache: List[List[torch.Tensor]],  # [virtual_engine][layer_index]
+) -> None:
     # Bind the kv_cache tensor to Attention modules, similar to
-    # ctx[layer_name].kv_cache = kv_cache[extract_layer_index(layer_name)]
+    # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
     # Special things handled here:
     # 1. Some models have non-attention layers, e.g., Jamba
     # 2. Pipeline parallelism, each rank only has a subset of layers
     # 3. Encoder attention has no kv cache
     # 4. Encoder-decoder models, encoder-decoder attention and decoder-only
     #    attention of the same layer (e.g., bart's decoder.layers.1.self_attn
-    #    and decoder.layers.1.encoder_attn is mapped to the same kv cache tensor
+    #    and decoder.layers.1.encoder_attn) is mapped to the same kv cache
+    #    tensor
     from vllm.attention import AttentionType
     from vllm.model_executor.models.utils import extract_layer_index
     layer_need_kv_cache = [
@@ -1974,4 +1978,7 @@ def bind_kv_cache(ctx: Dict[str, Any], kv_cache: List[torch.Tensor]) -> None:
         kv_cache_idx = layer_index_sorted.index(
             extract_layer_index(layer_name))
         forward_ctx = ctx[layer_name]
-        forward_ctx.kv_cache = kv_cache[kv_cache_idx]
+        assert len(forward_ctx.kv_cache) == len(kv_cache)
+        for ve, ve_kv_cache in enumerate(kv_cache):
+            assert forward_ctx.kv_cache[ve].numel() == 0
+            forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
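
bind_kv_cache now expects a nested list indexed as kv_cache[virtual_engine][layer_index] and replaces each layer's per-virtual-engine placeholder tensor in turn; the V1 GPU runner below simply wraps its per-layer list as [self.kv_caches] because it runs a single virtual engine. A toy illustration of that layout, where FakeLayerCtx is a hypothetical stand-in for an Attention layer's entry in the forward context (not vLLM code):

# Toy illustration of the kv_cache[virtual_engine][layer_index] layout.
from typing import List

import torch

pipeline_parallel_size = 2
num_layers = 3

# One kv cache tensor per (virtual engine, layer) pair.
kv_cache: List[List[torch.Tensor]] = [
    [torch.zeros(8) for _ in range(num_layers)]
    for _ in range(pipeline_parallel_size)
]


class FakeLayerCtx:

    def __init__(self):
        # Placeholders, one per virtual engine, as in Attention.__init__ above.
        self.kv_cache = [
            torch.tensor([]) for _ in range(pipeline_parallel_size)
        ]


layer_ctx = FakeLayerCtx()
layer_index = 1  # pretend extract_layer_index() returned 1 for this layer
for ve, ve_kv_cache in enumerate(kv_cache):
    assert layer_ctx.kv_cache[ve].numel() == 0  # still the placeholder
    layer_ctx.kv_cache[ve] = ve_kv_cache[layer_index]

print(layer_ctx.kv_cache[0].shape)  # torch.Size([8])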

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -863,4 +863,4 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
                 device=self.device))
         bind_kv_cache(
             self.vllm_config.compilation_config.static_forward_context,
-            self.kv_caches)
+            [self.kv_caches])

vllm/worker/cache_engine.py

Lines changed: 5 additions & 13 deletions
@@ -4,11 +4,10 @@
 import torch
 
 from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
-                         ModelConfig, ParallelConfig)
+from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
-                        bind_kv_cache, get_dtype_size, is_pin_memory_available)
+                        get_dtype_size, is_pin_memory_available)
 
 logger = init_logger(__name__)
 
@@ -21,14 +20,9 @@ class CacheEngine:
     as swapping and copying.
     """
 
-    def __init__(
-        self,
-        cache_config: CacheConfig,
-        model_config: ModelConfig,
-        parallel_config: ParallelConfig,
-        device_config: DeviceConfig,
-        compilation_config: CompilationConfig,
-    ) -> None:
+    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
+                 parallel_config: ParallelConfig,
+                 device_config: DeviceConfig) -> None:
         self.cache_config = cache_config
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -64,8 +58,6 @@ def __init__(
         self.gpu_cache = self._allocate_kv_cache(
             self.num_gpu_blocks, self.device_config.device_type)
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
-        bind_kv_cache(compilation_config.static_forward_context,
-                      self.gpu_cache)
 
     def _allocate_kv_cache(
         self,

vllm/worker/cpu_enc_dec_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -305,7 +305,8 @@ def execute_model(
             intermediate_tensors,
         }
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)
 
         # Compute the logits.

vllm/worker/cpu_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -526,7 +526,8 @@ def execute_model(
         execute_model_kwargs.update(
             {"previous_hidden_states": previous_hidden_states})
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,

vllm/worker/cpu_pooling_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ def execute_model(
             intermediate_tensors,
         }
 
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_states = model_executable(**execute_model_kwargs)
 
         # Only perform pooling in the driver worker.

vllm/worker/cpu_worker.py

Lines changed: 7 additions & 7 deletions
@@ -6,8 +6,8 @@
 
 import vllm.envs as envs
 from vllm.attention import get_attn_backend
-from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
-                         ModelConfig, ParallelConfig, VllmConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, VllmConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
@@ -33,8 +33,8 @@ class CPUCacheEngine:
     """
 
     def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
-                 parallel_config: ParallelConfig, device_config: DeviceConfig,
-                 compilation_config: CompilationConfig) -> None:
+                 parallel_config: ParallelConfig,
+                 device_config: DeviceConfig) -> None:
         assert device_config.device_type == "cpu"
         self.cache_config = cache_config
         self.model_config = model_config
@@ -66,8 +66,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
 
         # Initialize the cache.
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
-        bind_kv_cache(compilation_config.static_forward_context,
-                      self.cpu_cache)
 
     def _allocate_kv_cache(
         self,
@@ -292,13 +290,15 @@ def _init_cache_engine(self) -> None:
                 self.model_config,
                 self.parallel_config,
                 self.device_config,
-                self.compilation_config,
             ) for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.cpu_cache = [
             self.cache_engine[ve].cpu_cache
             for ve in range(self.parallel_config.pipeline_parallel_size)
         ]
+        for ve in range(self.parallel_config.pipeline_parallel_size):
+            bind_kv_cache(self.compilation_config.static_forward_context,
+                          self.cpu_cache[ve], ve)
         self.model_runner.block_size = self.cache_engine[0].block_size
 
         assert all(
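
Putting the worker-side pieces together: CacheEngine no longer binds its cache itself; the worker builds one cache engine per virtual engine and binds each engine's cache into the shared forward context, then passes the step's virtual engine through set_forward_context at execution time. A compact sketch of that wiring with stand-in objects (not the real CacheEngine or worker classes):

# Sketch of the per-virtual-engine wiring in _init_cache_engine above.
PP_SIZE = 2  # parallel_config.pipeline_parallel_size


class SketchCacheEngine:

    def __init__(self, ve: int):
        # One per-layer cache list per virtual engine.
        self.cpu_cache = [f"layer-cache ve={ve} layer={i}" for i in range(2)]


cache_engines = [SketchCacheEngine(ve) for ve in range(PP_SIZE)]
cpu_cache = [cache_engines[ve].cpu_cache for ve in range(PP_SIZE)]

for ve in range(PP_SIZE):
    # Stands in for bind_kv_cache(static_forward_context, cpu_cache[ve], ve);
    # each virtual engine's cache ends up in the same shared forward context.
    print(f"bind virtual engine {ve}: {cpu_cache[ve]}")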

vllm/worker/enc_dec_model_runner.py

Lines changed: 2 additions & 1 deletion
@@ -175,7 +175,8 @@ def execute_model(
         } if self.has_inner_state else {}
 
         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        with set_forward_context(model_input.attn_metadata, self.vllm_config):
+        with set_forward_context(model_input.attn_metadata, self.vllm_config,
+                                 model_input.virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
