 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
-EPOCHS = 1
-LOG_INTERVAL = 1
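+# NOTE: bumped from the 1-epoch, log-every-step profiling values to settings for a longer run.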
+EPOCHS = 10
+LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
 # VALIDATE_INTERVAL = 20
 
@@ -125,6 +125,7 @@ def _recv_from_teacher(self):
             req.wait()
 
     def _get_distill_kwargs(self):
128+ """Return a copy of received buffer for student training."""
         return {k: v.clone().detach() for k, v in self.student_recv_buffer.items()}
 
     def _send_to_student(self, teacher_outputs):
@@ -141,25 +142,6 @@ def _send_to_student(self, teacher_outputs):
         for req in reqs:
             req.wait()
 
-    # def _validate_ar(self, steps=3, osl=20, num_samples=20):
-    #     if self.rank != self.args.student_rank:
-    #         return
-    #     # Load MT-Bench prompts from HuggingFace
-    #     ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
-    #     self.model.eval()
-    #     self.model.to(self.args.student_device)
-    #     ars = validate_ar(
-    #         self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device
-    #     )
-    #     # Print results
-    #     avg_ar = sum(ars) / len(ars)
-    #     print("\n==== AR Validation Results on MT-Bench ====")
-    #     print(f"Number of samples: {len(ars)}")
-    #     print(f"Output Sequence Length: {osl}")
-    #     print(f"Steps: {steps}")
-    #     print(f"Average AR: {avg_ar:.4f}")
-    #     self.model.train()
-
     def train(self, dataloader):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()
@@ -174,19 +156,24 @@ def train(self, dataloader):
             project=os.environ["WANDB_PROJECT"],
             config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
         ) as run:
-            self.model, self.optimizer = self.load_student_model()
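+            # Unpack the LR scheduler along with the model and optimizer.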
+            self.model, self.optimizer, self.scheduler = self.load_student_model()
             self._init_student_recv_buffer()
             wandb.watch(self.model, log="all")
 
             for epoch in range(EPOCHS):
-                pbar = tqdm(dataloader)
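+                # Show a tqdm progress bar only on the lead student rank.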
+                pbar = (
+                    tqdm(dataloader) if self.rank == self.args.student_ranks[0] else dataloader
+                )
                 for i, batch in enumerate(pbar):
                     global_step = epoch * len(dataloader) + i
                     inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                     self._recv_from_teacher()
                     loss, train_acc = self.student_step(inputs, **self._get_distill_kwargs())
-                    pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
 
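+                    # Ranks other than the lead student rank skip progress display, logging, and saving.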
+                    if self.rank != self.args.student_ranks[0]:
+                        continue
+
+                    pbar.set_description(f"Epoch {epoch} Loss:{loss} Acc:{train_acc}")
                     if global_step % LOG_INTERVAL == 0:
                         run.log(
                             {
@@ -195,14 +182,10 @@ def train(self, dataloader):
195182 "train_acc_step1" : train_acc [1 ],
196183 "train_acc_step2" : train_acc [2 ],
197184 "train_acc_step3" : train_acc [3 ],
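+                                # Current learning rate of the first param group.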
185+ "lr" : self .optimizer .param_groups [0 ]["lr" ],
198186 },
199187 step = global_step ,
200188 )
-
-                    # This is not working for some reason.
-                    # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0:
-                    #     self._validate_ar()
-
                     if global_step > 0 and global_step % SAVE_INTERVAL == 0:
                         self.save_pretrained(
                             f"{self.args.out_path}/epoch_{epoch}_step_{global_step}"