@@ -79,7 +79,6 @@ def varlen_collate_fn(batch):
7979 Packed (input_dict, label) with collapsed batch dimension
8080 """
8181 if len (batch ) == 1 :
82- # Single sample - already packed
8382 input_dict , label = batch [0 ]
8483 return {
8584 "input" : input_dict ["input" ].unsqueeze (0 ), # [1, seq_len]
@@ -89,7 +88,6 @@ def varlen_collate_fn(batch):
8988 "max_k" : input_dict ["max_k" ],
9089 }, label .unsqueeze (0 ) # [1, seq_len]
9190
92- # Multiple samples - pack them together
9391 inputs = []
9492 labels = []
9593 cu_seqlens_list = []
@@ -100,23 +98,17 @@ def varlen_collate_fn(batch):
10098 inputs .append (input_dict ["input" ])
10199 labels .append (label )
102100
103- # Get cu_seqlens from this sample and adjust by offset
104101 cu_seqlens = input_dict ["cu_seq_q" ]
105- # Don't include the last boundary (we'll add it at the end)
106102 cu_seqlens_adjusted = cu_seqlens [:- 1 ] + offset
107103 cu_seqlens_list .append (cu_seqlens_adjusted )
108104
109- # Track maximum sequence length across all samples
110105 max_seqlen = max (max_seqlen , input_dict ["max_q" ])
111106
112- # Update offset for next sample
113107 offset += len (input_dict ["input" ])
114108
115- # Concatenate all inputs and labels
116- packed_input = torch .cat (inputs , dim = 0 ).unsqueeze (0 ) # Shape: [total_tokens]
117- packed_label = torch .cat (labels , dim = 0 ).unsqueeze (0 ) # Shape: [total_tokens]
109+ packed_input = torch .cat (inputs , dim = 0 ).unsqueeze (0 ) # shape: [1, total_tokens]
110+ packed_label = torch .cat (labels , dim = 0 ).unsqueeze (0 ) # shape: [1, total_tokens]
118111
119- # Combine all cu_seqlens and add final boundary
120112 packed_cu_seqlens = torch .cat (
121113 cu_seqlens_list + [torch .tensor ([offset ], dtype = torch .int32 )]
122114 )
@@ -189,7 +181,6 @@ def __iter__(self):
189181
190182 # marks where this current document ends
191183 if self .use_varlen_attn :
192- # if self.use_varlen_attn or self.use_flex_attn:
193184 self ._boundary_buffer .append (len (self ._token_buffer ))
194185
195186 while len (self ._token_buffer ) >= max_buffer_token_len :
@@ -198,19 +189,16 @@ def __iter__(self):
198189 # update tokens to the remaining tokens
199190 self ._token_buffer = self ._token_buffer [max_buffer_token_len :]
200191
201- input = x [:- 1 ] # print device here
192+ input = x [:- 1 ]
202193 label = x [1 :]
203194
204195 if self .use_varlen_attn :
205- # if self.use_varlen_attn or self.use_flex_attn:
206196 boundaries_in_window = [
207197 b for b in self ._boundary_buffer
208198 if b <= max_buffer_token_len
209199 ]
210200
211201 cu_seqlens = torch .tensor (boundaries_in_window , dtype = torch .int32 )
212- # print device here
213-
214202
215203 self ._boundary_buffer = [
216204 b - max_buffer_token_len
0 commit comments