Skip to content

Commit 406d6bf

Browse files
Authored commit: Add MLA support for v1 disagg connector (#6)
Signed-off-by: remi <[email protected]>
1 parent 1d8415d commit 406d6bf

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py

Lines changed: 27 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
1212
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
1313
from vllm.logger import init_logger
14+
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
1415
from vllm.v1.core.sched.output import SchedulerOutput
1516

1617
if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
9899
The number of elements in kv_caches and layer_names should be
99100
the same.
100101
"""
102+
attn_metadata = forward_context.attn_metadata
101103

102104
def inject_kv_into_layer(
103105
dst_kv_cache_layer: torch.Tensor,
@@ -108,19 +110,29 @@ def inject_kv_into_layer(
108110
109111
Args:
110112
dst_kv_cache_layer (torch.Tensor): the destination KV cache
111-
layer. In shape [2, num_pages, page_size, xxx].
113+
layer. In shape [2, num_pages, page_size, xxx] if not
114+
using MLA, [num_pages, page_size, xxx] otherwise.
112115
src_kv_cache (torch.Tensor): the source KV cache. In shape
113-
[2, num_tokens, xxx].
116+
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
117+
otherwise.
114118
slot_mapping (torch.Tensor): the slot mapping. In shape
115119
[num_tokens].
116120
"""
117121
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
118-
num_pages = dst_kv_cache_layer_shape[1]
119-
page_size = dst_kv_cache_layer_shape[2]
120-
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
121-
2, num_pages * page_size, -1)
122-
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
123-
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
122+
if isinstance(attn_metadata, MLACommonMetadata):
123+
num_pages = dst_kv_cache_layer_shape[0]
124+
page_size = dst_kv_cache_layer_shape[1]
125+
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
126+
num_pages * page_size, -1)
127+
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
128+
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
129+
else:
130+
num_pages = dst_kv_cache_layer_shape[1]
131+
page_size = dst_kv_cache_layer_shape[2]
132+
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
133+
2, num_pages * page_size, -1)
134+
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
135+
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
124136

125137
# Get the metadata
126138
metadata: KVConnectorMetadata = \
@@ -170,7 +182,7 @@ def wait_for_layer_load(self, layer_name: str) -> None:
170182

171183
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
172184
attn_metadata: "AttentionMetadata", **kwargs) -> None:
173-
"""Start saving the a layer of KV cache from vLLM's paged buffer
185+
"""Start saving the KV cache of the layer from vLLM's paged buffer
174186
to the connector.
175187
176188
Args:
@@ -187,10 +199,13 @@ def extract_kv_from_layer(
187199
) -> torch.Tensor:
188200
"""Extract the KV cache from the layer.
189201
190-
Assume the shape of the layer is (2, num_pages, page_size, xxx).
202+
Assume the shape of the layer is (2, num_pages, page_size, xxx)
203+
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
191204
"""
192-
# TODO: make this compatible with MLA.
193-
assert layer.shape[0] == 2
205+
if isinstance(attn_metadata, MLACommonMetadata):
206+
num_pages, page_size = layer.shape[0], layer.shape[1]
207+
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
208+
...]
194209
num_pages, page_size = layer.shape[1], layer.shape[2]
195210
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
196211
...]

0 commit comments

Comments
 (0)