55import torch .nn .functional as F
66from einops import rearrange , repeat
77
8+ from vllm .model_executor .layers .mamba .mamba2_metadata import (
9+ _seq_idx_to_chunk_indices_offsets )
810from vllm .model_executor .layers .mamba .ops .ssd_combined import (
911 mamba_chunk_scan_combined )
1012from vllm .platforms import current_platform
@@ -160,14 +162,14 @@ def end_boundary(n: int):
160162
161163 # get the metadata
162164 cu_seqlens = torch .tensor ((0 , ) + spec , device = device ).cumsum (dim = 0 )
163- sed_idx = torch .zeros (cu_seqlens [- 1 ],
165+ seq_idx = torch .zeros (cu_seqlens [- 1 ],
164166 dtype = torch .int32 ,
165167 device = cu_seqlens .device )
166168 for i , (srt , end ) in enumerate (zip (
167169 cu_seqlens ,
168170 cu_seqlens [1 :],
169171 )):
170- sed_idx [srt :end ] = i
172+ seq_idx [srt :end ] = i
171173
172174 # for cont batch
173175 if IND_E is None :
@@ -177,7 +179,7 @@ def end_boundary(n: int):
177179 IND_E = [end_boundary (x + y ) for x , y in zip (IND_S , spec )]
178180
179181 yield ([Y_min [s , IND_S [s ]:IND_E [s ]] for s in range (num_examples )],
180- cu_seqlens , sed_idx .unsqueeze (0 ), (A , dt2 , X2 , B2 , C2 ))
182+ cu_seqlens , seq_idx .unsqueeze (0 ), (A , dt2 , X2 , B2 , C2 ))
181183
182184
183185@pytest .mark .parametrize ("itype" ,
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
266268 exhausted : dict = {} # map: eg -> boolean indicating example is exhausted
267269
268270 states = None
269- for Y_min , cu_seqlens , sed_idx , (A , dt , X , B ,
271+ for Y_min , cu_seqlens , seq_idx , (A , dt , X , B ,
270272 C ) in generate_continous_batched_examples (
271273 cases , num_examples , seqlen ,
272274 last_taken , exhausted , n_heads ,
273275 d_head , itype ):
274276
277+ chunk_indices , chunk_offsets = _seq_idx_to_chunk_indices_offsets (
278+ seq_idx , chunk_size )
279+
275280 Y , new_states = mamba_chunk_scan_combined (
276281 X ,
277282 dt ,
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
281286 chunk_size ,
282287 D = None ,
283288 cu_seqlens = cu_seqlens ,
284- seq_idx = sed_idx ,
289+ seq_idx = seq_idx ,
290+ chunk_indices = chunk_indices ,
291+ chunk_offsets = chunk_offsets ,
285292 return_varlen_states = True ,
286293 initial_states = states ,
287294 )
0 commit comments