1818from torchtitan .components .tokenizer import BaseTokenizer
1919from torchtitan .config import JobConfig
2020from torchtitan .hf_datasets import DatasetConfig
21- from torchtitan .protocols import train_spec
2221from torchtitan .tools .logging import logger
2322
2423
@@ -68,63 +67,6 @@ def _validate_dataset(
6867 return path , config .loader , config .sample_processor
6968
7069
def varlen_collate_fn(batch):
    """Collate (input_dict, label) pairs for variable-length attention.

    Packs every sample in the batch into one contiguous sequence (batch
    dimension collapsed to 1) and shifts each sample's cumulative
    sequence-length boundaries so they stay valid in the packed sequence.

    Args:
        batch: list of (input_dict, label) tuples, where input_dict carries
            "input", "cu_seq_q", "cu_seq_k", "max_q", "max_k".

    Returns:
        A packed (input_dict, label) pair; "input" and the label have shape
        [1, total_tokens].
    """
    # Fast path: a single sample only needs a leading batch dimension.
    if len(batch) == 1:
        sample, target = batch[0]
        packed = {
            "input": sample["input"].unsqueeze(0),  # [1, seq_len]
            "cu_seq_q": sample["cu_seq_q"],
            "cu_seq_k": sample["cu_seq_k"],
            "max_q": sample["max_q"],
            "max_k": sample["max_k"],
        }
        return packed, target.unsqueeze(0)  # [1, seq_len]

    token_chunks = []
    label_chunks = []
    boundary_chunks = []
    total_tokens = 0
    longest = 0

    for sample, target in batch:
        token_chunks.append(sample["input"])
        label_chunks.append(target)

        # Drop each sample's trailing boundary: once shifted by the running
        # token count it coincides with the next sample's starting boundary.
        boundary_chunks.append(sample["cu_seq_q"][:-1] + total_tokens)

        longest = max(longest, sample["max_q"])
        total_tokens += len(sample["input"])

    # Close the packed sequence with the final boundary (total token count).
    cu_seqlens = torch.cat(
        boundary_chunks + [torch.tensor([total_tokens], dtype=torch.int32)]
    )

    packed_input = torch.cat(token_chunks, dim=0).unsqueeze(0)  # [1, total_tokens]
    packed_label = torch.cat(label_chunks, dim=0).unsqueeze(0)  # [1, total_tokens]

    return {
        "input": packed_input,
        "cu_seq_q": cu_seqlens,
        "cu_seq_k": cu_seqlens,
        "max_q": longest,
        "max_k": longest,
    }, packed_label
127-
12870class HuggingFaceTextDataset (IterableDataset , Stateful ):
12971 def __init__ (
13072 self ,
@@ -155,9 +97,6 @@ def __init__(
15597 self ._sample_idx = 0
15698 self ._token_buffer : list [int ] = []
15799
158- self ._boundary_buffer : list [int ] = [0 ]
159- self .use_varlen_attn : bool = False
160-
161100 def _get_data_iter (self ):
162101 # For map-style datasets, resume by skipping to the correct index
163102 # For iterable-style datasets, the underlying iterator already points to the correct index
@@ -182,63 +121,13 @@ def __iter__(self):
182121 self ._token_buffer .extend (sample_tokens )
183122 self ._sample_idx += 1
184123
185- if self .use_varlen_attn :
186- self ._boundary_buffer .append (len (self ._token_buffer ))
187-
188124 while len (self ._token_buffer ) >= max_buffer_token_len :
189125 x = torch .LongTensor (self ._token_buffer [:max_buffer_token_len ])
190-
191126 # update tokens to the remaining tokens
192127 self ._token_buffer = self ._token_buffer [max_buffer_token_len :]
193-
194128 input = x [:- 1 ]
195129 label = x [1 :]
196-
197- if self .use_varlen_attn :
198- boundaries_in_window = [
199- b
200- for b in self ._boundary_buffer
201- if b <= max_buffer_token_len
202- ]
203-
204- cu_seqlens = torch .tensor (
205- boundaries_in_window , dtype = torch .int32
206- )
207-
208- self ._boundary_buffer = [
209- b - max_buffer_token_len
210- for b in self ._boundary_buffer
211- if b > max_buffer_token_len
212- ]
213-
214- if not self ._boundary_buffer or self ._boundary_buffer [0 ] != 0 :
215- self ._boundary_buffer .insert (0 , 0 )
216-
217- cu_seqlens_input = cu_seqlens [cu_seqlens <= len (input )]
218- if cu_seqlens_input [- 1 ] != len (input ):
219- cu_seqlens_input = torch .cat (
220- [
221- cu_seqlens_input ,
222- torch .tensor ([len (input )], dtype = torch .int32 ),
223- ]
224- )
225-
226- seq_lengths = torch .diff (cu_seqlens_input )
227- max_seqlen = (
228- seq_lengths .max ().item ()
229- if len (seq_lengths ) > 0
230- else self .seq_len
231- )
232-
233- yield {
234- "input" : input ,
235- "cu_seq_q" : cu_seqlens_input ,
236- "cu_seq_k" : cu_seqlens_input ,
237- "max_q" : max_seqlen ,
238- "max_k" : max_seqlen ,
239- }, label
240- else :
241- yield {"input" : input }, label
130+ yield {"input" : input }, label
242131
243132 if not self .infinite :
244133 logger .warning (f"Dataset { self .dataset_name } has run out of data" )
@@ -256,7 +145,6 @@ def __iter__(self):
256145
257146 def load_state_dict (self , state_dict ):
258147 self ._token_buffer = state_dict ["token_buffer" ]
259- self ._boundary_buffer = state_dict .get ("boundary_buffer" , [0 ])
260148
261149 if isinstance (self ._data , Dataset ):
262150 self ._sample_idx = state_dict ["sample_idx" ]
@@ -265,10 +153,7 @@ def load_state_dict(self, state_dict):
265153 self ._data .load_state_dict (state_dict ["data" ])
266154
267155 def state_dict (self ):
268- _state_dict = {
269- "token_buffer" : self ._token_buffer ,
270- "boundary_buffer" : self ._boundary_buffer ,
271- }
156+ _state_dict = {"token_buffer" : self ._token_buffer }
272157
273158 if isinstance (self ._data , Dataset ):
274159 _state_dict ["sample_idx" ] = self ._sample_idx
@@ -293,11 +178,6 @@ def build_text_dataloader(
293178 batch_size = job_config .training .local_batch_size
294179 seq_len = job_config .training .seq_len
295180
296- model_args = train_spec .get_train_spec (job_config .model .name ).model_args [
297- job_config .model .flavor
298- ]
299- use_varlen_attn = getattr (model_args , "use_varlen_attn" , False )
300-
301181 hf_ds = HuggingFaceTextDataset (
302182 dataset_name = dataset_name ,
303183 dataset_path = dataset_path ,
@@ -307,16 +187,12 @@ def build_text_dataloader(
307187 dp_world_size = dp_world_size ,
308188 infinite = infinite ,
309189 )
310- hf_ds .use_varlen_attn = use_varlen_attn
311-
312- collate_fn = varlen_collate_fn if use_varlen_attn else None
313190
314191 return ParallelAwareDataloader (
315192 dataset = hf_ds ,
316193 dp_rank = dp_rank ,
317194 dp_world_size = dp_world_size ,
318195 batch_size = batch_size ,
319- collate_fn = collate_fn ,
320196 )
321197
322198
0 commit comments