Skip to content

Commit 438c616

Browse files
committed
profile
Signed-off-by: h-guo18 <[email protected]>
1 parent 8d6a49b commit 438c616

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

examples/speculative_decoding/distill_trainer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
import json
1616
import os
17+
import time
1718

1819
os.environ["TOKENIZERS_PARALLELISM"] = "false"
1920
from abc import abstractmethod
@@ -42,6 +43,7 @@
4243
# Hyperparameters for profiling
4344
LOG_INTERVAL = 100
4445
SAVE_INTERVAL = 20000
46+
TOTAL_STEPS = 500
4547

4648
# Shape and dtype description of the distillation signal
4749
DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]
@@ -204,6 +206,10 @@ def train(self):
204206
)
205207
for i, batch in enumerate(pbar):
206208
global_step = epoch * len(self.dataloader) + i
209+
if global_step >= TOTAL_STEPS:
210+
break
211+
if global_step == 50:
212+
self.start_time = time.time()
207213
inputs = {k: v.to(self.model.device) for k, v in batch.items()}
208214

209215
# Receive distill messages from teacher
@@ -241,10 +247,18 @@ def train(self):
241247
# Inference Loop
242248
for epoch in range(self.args.epoch):
243249
for i, batch in enumerate(self.dataloader):
250+
global_step = epoch * len(self.dataloader) + i
251+
if global_step >= TOTAL_STEPS:
252+
break
253+
if global_step == 50:
254+
self.start_time = time.time()
244255
inputs = {k: v.to(self.model.device) for k, v in batch.items()}
245256
with torch.inference_mode():
246257
self._send_to_student(self.teacher_step(self.model, inputs))
247258

259+
self.average_step_time = (time.time() - self.start_time) / (TOTAL_STEPS-50)
260+
print(f"Rank {self.rank} average step time: {self.average_step_time}")
261+
248262
self._print_mem_stats()
249263
# Make sure all processes have finished before destroying.
250264
dist.barrier()
@@ -321,6 +335,7 @@ def _prepare_student_model(self):
321335
process_group=self.args.student_pgroup,
322336
find_unused_parameters=True,
323337
)
338+
self._print_mem_stats()
324339
return model
325340

326341
@property

examples/speculative_decoding/eagle_config.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66
"original_max_position_embeddings": 8192,
77
"rope_type": "llama3"
88
},
9-
"initializer_range": 0.02
9+
"initializer_range": 0.02,
10+
"head_dim": 64
1011
}

examples/speculative_decoding/eagle_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@
4242

4343

4444
def preprocess(examples, tokenizer):
45-
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
45+
if tokenizer.chat_template:
46+
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
47+
else:
48+
tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
4649
new_examples = {
4750
"input_ids": [],
4851
"attention_mask": [],

examples/speculative_decoding/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main():
7171
parser = argparse.ArgumentParser(description="Multi-GPU distributed two-stage forward example")
7272
parser.add_argument("--model_path", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
7373
parser.add_argument("--student_devices", type=list, default=[0, 1, 2, 3])
74-
parser.add_argument("--teacher_devices", type=list, default=[4, 5])
74+
parser.add_argument("--teacher_devices", type=list, default=[4, 5, 6, 7])
7575
parser.add_argument(
7676
"--data_path", type=str, default="data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
7777
)

0 commit comments

Comments
 (0)