 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.logger import init_logger
+from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 
 if TYPE_CHECKING:
@@ -98,6 +99,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
             The number of elements in kv_caches and layer_names should be
             the same.
         """
+        attn_metadata = forward_context.attn_metadata
 
         def inject_kv_into_layer(
             dst_kv_cache_layer: torch.Tensor,
@@ -108,19 +110,29 @@ def inject_kv_into_layer(
 
             Args:
                 dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx].
+                    layer. In shape [2, num_pages, page_size, xxx] if not
+                    using MLA, [num_pages, page_size, xxx] otherwise.
                 src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx].
+                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
+                    otherwise.
                 slot_mapping (torch.Tensor): the slot mapping. In shape
                     [num_tokens].
             """
             dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            num_pages = dst_kv_cache_layer_shape[1]
-            page_size = dst_kv_cache_layer_shape[2]
-            dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                2, num_pages * page_size, -1)
-            dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
-            dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+            if isinstance(attn_metadata, MLACommonMetadata):
+                num_pages = dst_kv_cache_layer_shape[0]
+                page_size = dst_kv_cache_layer_shape[1]
+                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
+                    num_pages * page_size, -1)
+                dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+            else:
+                num_pages = dst_kv_cache_layer_shape[1]
+                page_size = dst_kv_cache_layer_shape[2]
+                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
+                    2, num_pages * page_size, -1)
+                dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
 
         # Get the metadata
         metadata: KVConnectorMetadata = \
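
For reference, here is a minimal standalone sketch of the scatter that inject_kv_into_layer now performs for both layouts. All shapes, slot values, and variable names below are made up for illustration; only the reshape-plus-slot_mapping indexing pattern mirrors the hunk above.

import torch

# Toy dimensions (hypothetical, for illustration only).
num_pages, page_size, head_dim, num_tokens = 4, 8, 16, 5
slot_mapping = torch.tensor([3, 7, 12, 20, 31])  # flat slot = page * page_size + offset

# Non-MLA layout: [2, num_pages, page_size, ...] holds K and V stacked on dim 0.
paged_kv = torch.zeros(2, num_pages, page_size, head_dim)
src_kv = torch.randn(2, num_tokens, head_dim)
flat = paged_kv.reshape(2, num_pages * page_size, -1)  # view over the same storage
flat[:, slot_mapping, ...] = src_kv                    # scatter tokens into their slots
assert torch.equal(paged_kv.reshape(2, -1, head_dim)[:, slot_mapping], src_kv)

# MLA layout: no separate K/V dimension, so the layer is [num_pages, page_size, ...]
# and the leading index is dropped.
paged_mla = torch.zeros(num_pages, page_size, head_dim)
src_mla = torch.randn(num_tokens, head_dim)
flat_mla = paged_mla.reshape(num_pages * page_size, -1)
flat_mla[slot_mapping, ...] = src_mla
assert torch.equal(paged_mla.reshape(-1, head_dim)[slot_mapping], src_mla)
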
@@ -170,7 +182,7 @@ def wait_for_layer_load(self, layer_name: str) -> None:
 
     def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                       attn_metadata: "AttentionMetadata", **kwargs) -> None:
-        """Start saving the a layer of KV cache from vLLM's paged buffer
+        """Start saving the KV cache of the layer from vLLM's paged buffer
         to the connector.
 
         Args:
@@ -187,10 +199,13 @@ def extract_kv_from_layer(
         ) -> torch.Tensor:
             """Extract the KV cache from the layer.
 
-            Assume the shape of the layer is (2, num_pages, page_size, xxx).
+            Assume the shape of the layer is (2, num_pages, page_size, xxx)
+            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
             """
-            # TODO: make this compatible with MLA.
-            assert layer.shape[0] == 2
+            if isinstance(attn_metadata, MLACommonMetadata):
+                num_pages, page_size = layer.shape[0], layer.shape[1]
+                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
+                                                                ...]
             num_pages, page_size = layer.shape[1], layer.shape[2]
             return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                                ...]
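
And the corresponding gather in extract_kv_from_layer, again sketched with hypothetical toy shapes rather than real vLLM tensors:

import torch

# MLA-style cache layer: [num_pages, page_size, ...] with no K/V dimension.
num_pages, page_size, head_dim = 4, 8, 16
layer = torch.randn(num_pages, page_size, head_dim)
slot_mapping = torch.tensor([0, 9, 17, 30])

# Flatten pages and gather the requested slots -> [num_tokens, head_dim].
extracted = layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
assert extracted.shape == (len(slot_mapping), head_dim)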