
Commit d3494c0

add docstring

Signed-off-by: h-guo18 <[email protected]>
1 parent: 77498d4

1 file changed: examples/speculative_decoding/distill_trainer.py (+7, -11)
@@ -14,7 +14,6 @@
 # limitations under the License.
 import json
 import os
-import time
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 from abc import abstractmethod
@@ -51,8 +50,11 @@
 
 class BaseDistillTrainer:
     """
-    Base distill trainer with basic training loop and overlapped teacher and student steps.
-    Initalized and called on every rank.
+    Base distill trainer.
+    Designed as a replacement for the HF trainer for several purposes:
+    1. Allow separate placement and parallelism for teacher and student.
+    2. Allow overlapped teacher and student steps.
+    3. Clean, minimal training loop to reduce compatibility issues.
     Args:
         rank: rank of the current process
         args: arguments
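For context on the overlap the new docstring describes: teacher ranks run inference and stream outputs to student ranks, which train concurrently. Below is a minimal sketch of that rank-split pattern, not the trainer's actual API; the group setup, tensor shapes, and the `student_ranks` list are illustrative assumptions.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int, total_steps: int = 4):
    # Illustrative split: rank 0 acts as teacher, remaining ranks as students.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    student_ranks = list(range(1, world_size))
    if rank in student_ranks:
        for _ in range(total_steps):
            logits = torch.empty(8)
            dist.recv(logits, src=0)          # teacher output for this step
            loss = logits.square().mean()     # stand-in for a distill loss + backward
    else:
        for _ in range(total_steps):
            with torch.inference_mode():
                logits = torch.randn(8)       # stand-in for teacher_step(...)
            for dst in student_ranks:
                dist.send(logits, dst=dst)    # teacher moves on to the next batch
    dist.barrier()                            # all ranks finish before teardown
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```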
@@ -70,8 +72,9 @@ def __init__(self, rank, args, tokenizer, dataloader):
         if rank in args.student_ranks:
             self.model = self._prepare_student_model()
             self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
+            # Same scheduler as HF trainer default
             self.scheduler = get_linear_schedule_with_warmup(
-                self.optimizer, num_warmup_steps=0, num_training_steps=117380
+                self.optimizer, num_warmup_steps=0, num_training_steps=TOTAL_STEPS
             )
         else:
             self.model = self._prepare_teacher_model()
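This hunk replaces the hard-coded step count with the module's TOTAL_STEPS constant. As the new comment notes, `get_linear_schedule_with_warmup` with zero warmup matches the HF Trainer default: the learning rate decays linearly to zero over the training run. A standalone sketch of that behavior (the TOTAL_STEPS value here is a placeholder, not the trainer's):

```python
import torch
from transformers import get_linear_schedule_with_warmup

TOTAL_STEPS = 1000  # placeholder; the trainer defines its own constant
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=TOTAL_STEPS
)
for step in range(3):
    optimizer.step()
    scheduler.step()
    # LR falls linearly: roughly lr * (1 - step / TOTAL_STEPS)
    print(step, scheduler.get_last_lr())
```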
@@ -208,8 +211,6 @@ def train(self):
                 global_step = epoch * len(self.dataloader) + i
                 if global_step >= TOTAL_STEPS:
                     break
-                if global_step == 50:
-                    self.start_time = time.time()
                 inputs = {k: v.to(self.model.device) for k, v in batch.items()}
 
                 # Receive distill messages from teacher
@@ -250,15 +251,10 @@ def train(self):
                 global_step = epoch * len(self.dataloader) + i
                 if global_step >= TOTAL_STEPS:
                     break
-                if global_step == 50:
-                    self.start_time = time.time()
                 inputs = {k: v.to(self.model.device) for k, v in batch.items()}
                 with torch.inference_mode():
                     self._send_to_student(self.teacher_step(self.model, inputs))
 
-        self.average_step_time = (time.time() - self.start_time) / (TOTAL_STEPS - 50)
-        print(f"Rank {self.rank} average step time: {self.average_step_time}")
-
         self._print_mem_stats()
         # Makesure all processes finished before destroy.
         dist.barrier()
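The lines deleted in the last two hunks were ad-hoc throughput instrumentation: they started a timer at step 50 and reported the mean step time over the remaining steps, excluding warmup. If that measurement is still wanted, the same pattern can live outside the trainer; a sketch with hypothetical names:

```python
import time

WARMUP_STEPS = 50   # skip early steps so compilation/allocator warmup is excluded
TOTAL_STEPS = 200   # placeholder value

start_time = None
for step in range(TOTAL_STEPS):
    if step == WARMUP_STEPS:
        start_time = time.time()
    pass  # one training step would run here

average_step_time = (time.time() - start_time) / (TOTAL_STEPS - WARMUP_STEPS)
print(f"average step time: {average_step_time:.4f}s")
```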
