
Commit 6f0ee81

awan-10, molly-smith, lekurile, and yaozhewei authored and committed

add timers and performance metrics (deepspeedai#688)

Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Molly Smith <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Zhewei Yao <[email protected]>
1 parent 40ed715 commit 6f0ee81

File tree

5 files changed: +162, -9 lines changed

applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,7 @@
 from utils.ds_utils import get_train_ds_config
 from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
 from utils.model.model_utils import create_hf_model
+from utils.perf import print_throughput
 
 
 def parse_args():
@@ -321,7 +322,9 @@ def evaluation(model, eval_dataloader):
             f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
             args.global_rank)
         model.train()
+        import time
         for step, batch in enumerate(train_dataloader):
+            start = time.time()
             batch = to_device(batch, device)
             outputs = model(**batch, use_cache=False)
             loss = outputs.loss
@@ -331,6 +334,10 @@ def evaluation(model, eval_dataloader):
             )
             model.backward(loss)
             model.step()
+            end = time.time()
+            if torch.distributed.get_rank() == 0:
+                print_throughput(model.model, args, end - start,
+                                 args.global_rank)
 
         # Evaluate perplexity on the validation set.
         print_rank_0(
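To see what the new helper expects, here is a minimal, hypothetical smoke test for print_throughput outside the training loop. It is a sketch, not part of the commit: it assumes utils.perf is importable as in the diff above (i.e. the training/ directory is on sys.path), and the model choice, the argparse.Namespace fields, and the pretend 0.5 s step time are made up for illustration.

# Hypothetical standalone check of the new print_throughput helper
# (a sketch, not part of this commit).
import argparse

import torch.distributed as dist
from transformers import AutoModelForCausalLM

from utils.perf import print_throughput

# print_throughput calls torch.distributed.get_world_size(), so even a
# single-process dry run needs an initialized process group.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)

# Any Hugging Face causal LM works; only .config and .parameters() are read.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Only the three args fields read by print_throughput are provided here.
args = argparse.Namespace(max_seq_len=512,
                          per_device_train_batch_size=4,
                          gradient_checkpointing=False)

# Pretend one forward/backward/step took 0.5 s end to end.
print_throughput(model, args, e2e_time=0.5, rank=0)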

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 23 additions & 8 deletions
@@ -19,6 +19,7 @@
 import argparse
 import os
 import random
+import time
 import torch
 from torch.utils.data import DataLoader, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -42,6 +43,7 @@
 from utils.data.data_utils import create_prompt_dataset, MiniDataset, DataCollatorRLHF, get_unsupervised_data
 from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
 from utils.module.lora import convert_lora_to_linear_layer
+from utils.perf import print_throughput_step3
 
 writer = None
 
@@ -478,13 +480,9 @@ def main():
             args.global_rank)
         for step, (batch_prompt, batch_unsupervised) in enumerate(
                 zip(prompt_train_dataloader, unsupervised_train_dataloader)):
+
             batch_prompt = to_device(batch_prompt, device)
-            if batch_unsupervised is not None:
-                batch_unsupervised = to_device(batch_unsupervised, device)
-                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
-            else:
-                unsup_dataset = unsup_mini_dataset.add(
-                    [[None] * args.per_device_generation_batch_size])
+
             # prompts = batch_prompt['prompt']
             # length = prompts.size(-1)
             # if length > args.max_prompt_seq_len:
@@ -494,6 +492,15 @@ def main():
             out = trainer.generate_experience(batch_prompt['prompt'],
                                               batch_prompt['prompt_att_mask'],
                                               step)
+
+            training_start = time.time()
+            if batch_unsupervised is not None:
+                batch_unsupervised = to_device(batch_unsupervised, device)
+                unsup_dataset = unsup_mini_dataset.add(batch_unsupervised)
+            else:
+                unsup_dataset = unsup_mini_dataset.add(
+                    [[None] * args.per_device_generation_batch_size])
+
             exp_dataset = exp_mini_dataset.add(out)
 
             if exp_dataset is not None:
@@ -526,16 +533,24 @@ def main():
                     random.shuffle(exp_dataset)
                     random.shuffle(unsup_dataset)
 
+                end = time.time()
+                training_time = end - training_start
+                e2e_time = training_time + trainer.generate_time * args.generation_batches  # it is an approximation, we did not include, e.g., rw forward time etc
+
                 print_rank_0(
-                    f'epoch: {epoch}|step: {step}|ppo_ep: {ppo_ep+1}|act_loss: {actor_loss_sum/inner_iter}|cri_loss: {critic_loss_sum/inner_iter}|unsuper_loss: {unsup_loss_sum/inner_iter}',
+                    f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
                     args.global_rank)
+                print_throughput_step3(rlhf_engine.actor.model, args, e2e_time,
+                                       trainer.generate_time, training_time,
+                                       args.global_rank)
                 average_reward = get_all_reduce_mean(average_reward).item()
                 print_rank_0(
-                    f"average reward score: {average_reward/inner_iter}",
+                    f"Average reward score: {average_reward/inner_iter}",
                     args.global_rank)
                 print_rank_0(
                     "-------------------------------------------------------------------------------------",
                     args.global_rank)
+
                 if args.enable_tensorboard and torch.distributed.get_rank(
                 ) == 0:
                     writer.add_scalar('reward',
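As the in-code comment above notes, the per-step end-to-end time is an approximation; written out, the new timing amounts to

$$
t_{\mathrm{e2e}} \;\approx\; t_{\mathrm{train}} \;+\; t_{\mathrm{gen}} \cdot N_{\mathrm{gen}},
$$

where $t_{\mathrm{gen}}$ is the generation time measured inside generate_experience (see the ppo_trainer.py change below) and $N_{\mathrm{gen}}$ is args.generation_batches; smaller contributions such as the reward-model forward pass are deliberately excluded.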

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py

Lines changed: 6 additions & 0 deletions
@@ -6,6 +6,7 @@
 import torch.nn.functional as F
 import sys
 import os
+import time
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 
@@ -65,6 +66,7 @@ def __init__(self, rlhf_engine, args):
         self.cliprange_value = 0.2
         self.gamma = 1.0
         self.lam = 0.95
+        self.generate_time = 0.0
 
     def _generate_sequence(self, prompts, mask, step):
 
@@ -116,7 +118,9 @@ def _generate_sequence(self, prompts, mask, step):
 
     def generate_experience(self, prompts, mask, step):
         self.eval()
+        generate_start = time.time()
         seq = self._generate_sequence(prompts, mask, step)
+        generate_end = time.time()
         self.train()
 
         pad_token_id = self.tokenizer.pad_token_id
@@ -134,6 +138,8 @@ def generate_experience(self, prompts, mask, step):
         logits = output.logits
         logits_ref = output_ref.logits
 
+        self.generate_time = generate_end - generate_start
+
         return {
             'prompts': prompts,
             'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/training_scripts/opt/single_node/run_1.3b.sh

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ deepspeed --master_port 12346 main.py \
    --critic_zero_stage $CRITIC_ZERO_STAGE \
    --enable_ema \
    --output_dir $OUTPUT \
-   --print_answers \
    --enable_tensorboard \
    --tensorboard_path $OUTPUT \
    &> $OUTPUT/training.log
applications/DeepSpeed-Chat/training/utils/perf.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+
+
+# This function can be used to print throughput for Step 1 and 2 only
+def print_throughput(hf_model, args, e2e_time, rank=0):
+    if rank <= 0:
+        hf_config = hf_model.config
+        num_layers = getattr(hf_config, "num_hidden_layers",
+                             getattr(hf_config, "n_layer", None))
+        hidden_size = getattr(hf_config, "hidden_size",
+                              getattr(hf_config, "n_embd", None))
+        vocab_size = getattr(hf_config, "vocab_size", None)
+        assert all(
+            (num_layers, hidden_size, vocab_size)
+        ), "Could not determine number of layers, hidden size, and vocab size of the model"
+
+        gpus_per_model = torch.distributed.get_world_size()
+        seq_length = args.max_seq_len
+        batch_size = args.per_device_train_batch_size
+        samples_per_second = batch_size / e2e_time
+        checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
+        hf_model._num_params = sum([
+            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
+            for p in hf_model.parameters()
+        ])
+        params_in_billions = hf_model._num_params / (1e9)
+
+        # Megatron paper's formula to calculate training flops
+        train_flops_per_iteration = (
+            24 * checkpoint_activations_factor * batch_size * seq_length *
+            num_layers *
+            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
+                                 (vocab_size /
+                                  (16.0 * num_layers * hidden_size)))
+
+        train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
+                                                    (10**12))
+
+        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
+        print(
+            f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
+        )
+
+
+# Enhanced version of the function above that provides calculations and printing for Step 3
+def print_throughput_step3(hf_model,
+                           args,
+                           e2e_time,
+                           gen_exp_time,
+                           train_time,
+                           rank=0):
+    if rank <= 0:
+        hf_config = hf_model.config
+        num_layers = getattr(hf_config, "num_hidden_layers",
+                             getattr(hf_config, "n_layer", None))
+        hidden_size = getattr(hf_config, "hidden_size",
+                              getattr(hf_config, "n_embd", None))
+        vocab_size = getattr(hf_config, "vocab_size", None)
+        assert all(
+            (num_layers, hidden_size, vocab_size)
+        ), "Could not determine number of layers, hidden size, and vocab size of the model"
+
+        gpus_per_model = torch.distributed.get_world_size()
+        seq_length = args.max_answer_seq_len + args.max_prompt_seq_len
+        batch_size = args.per_device_generation_batch_size * args.generation_batches * args.ppo_epochs * gpus_per_model * 1 if args.unsupervised_dataset_name is None else 2
+        samples_per_second = batch_size / e2e_time
+        checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
+        hf_model._num_params = sum([
+            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
+            for p in hf_model.parameters()
+        ])
+        params_in_billions = hf_model._num_params / (1e9)
+
+        # Megatron paper's formula to calculate training flops
+        train_flops_per_iteration = (
+            24 * checkpoint_activations_factor * batch_size * seq_length *
+            num_layers *
+            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
+                                 (vocab_size /
+                                  (16.0 * num_layers * hidden_size)))
+
+        train_tflops = train_flops_per_iteration / (train_time *
+                                                    gpus_per_model * (10**12))
+
+        gen_bs = args.per_device_generation_batch_size * gpus_per_model
+
+        # Modified formula for calculating flops in forward pass only
+        gen_flops_per_iteration = (
+            24 * gen_bs * seq_length * num_layers *
+            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
+                                 (vocab_size /
+                                  (16.0 * num_layers * hidden_size)))
+
+        gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
+                                                (10**12))
+
+        if hf_config.torch_dtype == "float16":
+            num_bytes = 2
+        elif hf_config.torch_dtype == "float32":
+            num_bytes = 4
+        else:
+            num_bytes = 1
+
+        gen_bw = (hf_model._num_params *
+                  (num_bytes / 1e9)) / gen_exp_time * args.max_answer_seq_len
+
+        total_flops_per_iteration = train_flops_per_iteration + gen_flops_per_iteration * args.generation_batches
+        total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model *
+                                                    (10**12))
+
+        print(
+            f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
+        )
+        print(
+            f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec"
+        )
+        print(
+            f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}"
+        )
+        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
+        print(f"Parameters => {param_string}")
