@@ -5,6 +5,8 @@
 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from vllm.model_executor.layers.mamba.mamba2_metadata import (
+    _seq_idx_to_chunk_indices_offsets)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined)
 from vllm.platforms import current_platform
@@ -160,14 +162,14 @@ def end_boundary(n: int):
 
         # get the metadata
         cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
-        sed_idx = torch.zeros(cu_seqlens[-1],
+        seq_idx = torch.zeros(cu_seqlens[-1],
                               dtype=torch.int32,
                               device=cu_seqlens.device)
         for i, (srt, end) in enumerate(zip(
                 cu_seqlens,
                 cu_seqlens[1:],
         )):
-            sed_idx[srt:end] = i
+            seq_idx[srt:end] = i
 
         # for cont batch
         if IND_E is None:
@@ -177,7 +179,7 @@ def end_boundary(n: int):
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
         yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
-               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
+               cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
 
 
 @pytest.mark.parametrize("itype",
@@ -266,12 +268,15 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
     exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
 
     states = None
-    for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
+    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B,
                                      C) in generate_continous_batched_examples(
                                          cases, num_examples, seqlen,
                                          last_taken, exhausted, n_heads,
                                          d_head, itype):
 
+        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
+            seq_idx, chunk_size)
+
         Y, new_states = mamba_chunk_scan_combined(
             X,
             dt,
@@ -281,7 +286,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
             chunk_size,
             D=None,
             cu_seqlens=cu_seqlens,
-            seq_idx=sed_idx,
+            seq_idx=seq_idx,
+            chunk_indices=chunk_indices,
+            chunk_offsets=chunk_offsets,
             return_varlen_states=True,
             initial_states=states,
         )
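
For reference, a minimal sketch (not part of the diff) of what the metadata built in this test encodes: with two sequences of lengths 3 and 4 packed into one varlen batch, the loop over cu_seqlens produces a per-token sequence index, which _seq_idx_to_chunk_indices_offsets (called with (seq_idx, chunk_size), as in the test) then turns into the chunk_indices/chunk_offsets bookkeeping that mamba_chunk_scan_combined needs when a sequence boundary falls inside a chunk. The concrete lengths below are illustrative only.

import torch

# cumulative sequence lengths for two packed sequences of lengths 3 and 4
cu_seqlens = torch.tensor([0, 3, 7])
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32)
for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
    seq_idx[srt:end] = i
# seq_idx is now tensor([0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)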