@@ -14,7 +14,6 @@
 # limitations under the License.
 import json
 import os
-import time
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from abc import abstractmethod
@@ -51,8 +50,11 @@
 
 class BaseDistillTrainer:
     """
-    Base distill trainer with basic training loop and overlapped teacher and student steps.
-    Initalized and called on every rank.
+    Base distill trainer.
+    Designed as a replacement for the HF Trainer, for several purposes:
+    1. Allow separate placement and parallelism for teacher and student.
+    2. Allow overlapped teacher and student steps.
+    3. Clean, minimal training loop to reduce compatibility issues.
     Args:
         rank: rank of the current process
         args: arguments
@@ -70,8 +72,9 @@ def __init__(self, rank, args, tokenizer, dataloader):
         if rank in args.student_ranks:
             self.model = self._prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
+            # Same scheduler as HF trainer default
             self.scheduler = get_linear_schedule_with_warmup(
-                self.optimizer, num_warmup_steps=0, num_training_steps=117380
+                self.optimizer, num_warmup_steps=0, num_training_steps=TOTAL_STEPS
             )
         else:
             self.model = self._prepare_teacher_model()
@@ -208,8 +211,6 @@ def train(self):
                 global_step = epoch * len(self.dataloader) + i
                 if global_step >= TOTAL_STEPS:
                     break
-                if global_step == 50:
-                    self.start_time = time.time()
                 inputs = {k: v.to(self.model.device) for k, v in batch.items()}
 
                 # Receive distill messages from teacher
@@ -250,15 +251,10 @@ def train(self):
                 global_step = epoch * len(self.dataloader) + i
                 if global_step >= TOTAL_STEPS:
                     break
-                if global_step == 50:
-                    self.start_time = time.time()
                 inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                 with torch.inference_mode():
                     self._send_to_student(self.teacher_step(self.model, inputs))
 
-        self.average_step_time = (time.time() - self.start_time) / (TOTAL_STEPS - 50)
-        print(f"Rank {self.rank} average step time: {self.average_step_time}")
-
         self._print_mem_stats()
         # Make sure all processes have finished before destroy.
         dist.barrier()
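
For context on the `TOTAL_STEPS` change in `__init__`: the hardcoded `117380` tied the learning-rate schedule to one specific run length, while `TOTAL_STEPS` lets the schedule track the actual run. A minimal sketch of the intended behavior, assuming the step count is derived from epoch count and dataloader length (the `num_epochs`/`steps_per_epoch` values below are illustrative, not from this commit):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# Illustrative stand-ins; the real trainer derives these from its args/dataloader.
num_epochs, steps_per_epoch = 3, 100
TOTAL_STEPS = num_epochs * steps_per_epoch

model = torch.nn.Linear(8, 8)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# HF Trainer's default schedule: linear decay to zero, with no warmup.
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=TOTAL_STEPS
)

for step in range(TOTAL_STEPS):
    optimizer.step()   # training step elided
    scheduler.step()   # LR decays linearly, reaching ~0 at TOTAL_STEPS
```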
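The docstring's "overlapped teacher and student steps" refers to teacher and student living on disjoint ranks and exchanging distill messages point-to-point, so the teacher can run inference on the next batch while the student is still optimizing on the current one. Below is a self-contained sketch of that pattern, assuming one student rank and one teacher rank on CPU with the gloo backend; the shapes, loss, and rank assignment are illustrative and not the commit's `_send_to_student`/teacher_step implementation:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

TOTAL_STEPS, BATCH, VOCAB = 8, 4, 16  # toy sizes, for illustration only

def run(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:  # student rank: trains against received teacher logits
        model = torch.nn.Linear(VOCAB, VOCAB)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        teacher_logits = torch.empty(BATCH, VOCAB)
        for _ in range(TOTAL_STEPS):
            # Blocking receive of this step's distill message.
            dist.recv(teacher_logits, src=1)
            student_logits = model(torch.randn(BATCH, VOCAB))
            loss = torch.nn.functional.kl_div(
                torch.log_softmax(student_logits, dim=-1),
                torch.softmax(teacher_logits, dim=-1),
                reduction="batchmean",
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    else:  # teacher rank: inference only, no optimizer state
        teacher = torch.nn.Linear(VOCAB, VOCAB)
        for _ in range(TOTAL_STEPS):
            with torch.inference_mode():
                logits = teacher(torch.randn(BATCH, VOCAB))
            # Teacher and student are separate processes, so the teacher
            # can start its next forward pass as soon as this send
            # completes, while the student is still optimizing.
            dist.send(logits.clone(), dst=0)

    dist.barrier()  # make sure all processes have finished before destroy
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

Because each side runs the same number of steps, the point-to-point send/recv pair is the only synchronization needed inside the loop; the final barrier mirrors the one at the end of `train()` above.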