
Commit e6b6d34

awan-10 authored and LeetJoe committed
Fix calculations (include critic model) for performance (deepspeedai#706)
1 parent: 6f0ee81 · commit: e6b6d34

2 files changed: +93 -52 lines changed

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 2 additions & 1 deletion
@@ -540,7 +540,8 @@ def main():
             print_rank_0(
                 f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
                 args.global_rank)
-            print_throughput_step3(rlhf_engine.actor.model, args, e2e_time,
+            print_throughput_step3(rlhf_engine.actor.model,
+                                   rlhf_engine.critic, args, e2e_time,
                                    trainer.generate_time, training_time,
                                    args.global_rank)
             average_reward = get_all_reduce_mean(average_reward).item()

applications/DeepSpeed-Chat/training/utils/perf.py

Lines changed: 91 additions & 51 deletions
@@ -10,33 +10,26 @@
 def print_throughput(hf_model, args, e2e_time, rank=0):
     if rank <= 0:
         hf_config = hf_model.config
-        num_layers = getattr(hf_config, "num_hidden_layers",
-                             getattr(hf_config, "n_layer", None))
-        hidden_size = getattr(hf_config, "hidden_size",
-                              getattr(hf_config, "n_embd", None))
-        vocab_size = getattr(hf_config, "vocab_size", None)
-        assert all(
-            (num_layers, hidden_size, vocab_size)
-        ), "Could not determine number of layers, hidden size, and vocab size of the model"
+        num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)

         gpus_per_model = torch.distributed.get_world_size()
         seq_length = args.max_seq_len
         batch_size = args.per_device_train_batch_size
         samples_per_second = batch_size / e2e_time
         checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
+        if args.lora_dim > 0:
+            k = args.lora_dim * 2 / hidden_size
+            checkpoint_activations_factor -= (1 - k)
+
         hf_model._num_params = sum([
             p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
             for p in hf_model.parameters()
         ])
         params_in_billions = hf_model._num_params / (1e9)

         # Megatron paper's formula to calculate training flops
-        train_flops_per_iteration = (
-            24 * checkpoint_activations_factor * batch_size * seq_length *
-            num_layers *
-            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
-                                 (vocab_size /
-                                  (16.0 * num_layers * hidden_size)))
+        train_flops_per_iteration = calculate_flops(
+            checkpoint_activations_factor, batch_size, seq_length, hf_config)

         train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model *
                                                     (10**12))
@@ -48,79 +41,126 @@ def print_throughput(hf_model, args, e2e_time, rank=0):


 # Enhanced version of the function above that provides calculations and printing for Step 3
-def print_throughput_step3(hf_model,
+def print_throughput_step3(actor_model,
+                           critic_model,
                            args,
                            e2e_time,
                            gen_exp_time,
                            train_time,
                            rank=0):
     if rank <= 0:
-        hf_config = hf_model.config
-        num_layers = getattr(hf_config, "num_hidden_layers",
-                             getattr(hf_config, "n_layer", None))
-        hidden_size = getattr(hf_config, "hidden_size",
-                              getattr(hf_config, "n_embd", None))
-        vocab_size = getattr(hf_config, "vocab_size", None)
-        assert all(
-            (num_layers, hidden_size, vocab_size)
-        ), "Could not determine number of layers, hidden size, and vocab size of the model"
+        # Actor model passed here is a HF model.
+        actor_hf_config = actor_model.config
+        # Critic model passed here is a DeepSpeed Engine. The module inside is the Reward model (that wraps a HF model).
+        critic_hf_config = critic_model.module.config
+
+        actor_num_layers, actor_hidden_size, actor_vocab_size = get_hf_configs(
+            actor_hf_config)
+        critic_num_layers, critic_hidden_size, critic_vocab_size = get_hf_configs(
+            critic_hf_config)

         gpus_per_model = torch.distributed.get_world_size()
         seq_length = args.max_answer_seq_len + args.max_prompt_seq_len
         batch_size = args.per_device_generation_batch_size * args.generation_batches * args.ppo_epochs * gpus_per_model * 1 if args.unsupervised_dataset_name is None else 2
         samples_per_second = batch_size / e2e_time
-        checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
-        hf_model._num_params = sum([
+
+        actor_checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3
+        critic_checkpoint_activations_factor = 4 if args.critic_gradient_checkpointing else 3
+        if args.actor_lora_dim > 0:
+            k = args.actor_lora_dim * 2 / actor_hidden_size
+            actor_checkpoint_activations_factor -= (1 - k)
+        if args.critic_lora_dim > 0:
+            k = args.critic_lora_dim * 2 / critic_hidden_size
+            critic_checkpoint_activations_factor -= (1 - k)
+
+        actor_model._num_params = sum([
             p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
-            for p in hf_model.parameters()
+            for p in actor_model.parameters()
         ])
-        params_in_billions = hf_model._num_params / (1e9)
+        actor_params_in_billions = actor_model._num_params / (1e9)
+
+        critic_model._num_params = sum([
+            p.ds_numel if hasattr(p, "ds_tensor") else p.numel()
+            for p in critic_model.parameters()
+        ])
+        critic_params_in_billions = critic_model._num_params / (1e9)

         # Megatron paper's formula to calculate training flops
-        train_flops_per_iteration = (
-            24 * checkpoint_activations_factor * batch_size * seq_length *
-            num_layers *
-            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
-                                 (vocab_size /
-                                  (16.0 * num_layers * hidden_size)))

-        train_tflops = train_flops_per_iteration / (train_time *
-                                                    gpus_per_model * (10**12))
+        actor_train_flops_per_iteration = calculate_flops(
+            actor_checkpoint_activations_factor, batch_size, seq_length,
+            actor_hf_config)
+        critic_train_flops_per_iteration = calculate_flops(
+            critic_checkpoint_activations_factor, batch_size, seq_length,
+            critic_hf_config)
+
+        total_train_flops = actor_train_flops_per_iteration + critic_train_flops_per_iteration
+        train_tflops = total_train_flops / (train_time * gpus_per_model *
+                                            (10**12))

         gen_bs = args.per_device_generation_batch_size * gpus_per_model

-        # Modified formula for calculating flops in forward pass only
+        # Modified formula for calculating flops in the forward pass only
         gen_flops_per_iteration = (
-            24 * gen_bs * seq_length * num_layers *
-            (hidden_size**2)) * (1.0 + (seq_length / (6.0 * hidden_size)) +
-                                 (vocab_size /
-                                  (16.0 * num_layers * hidden_size)))
+            24 * gen_bs * seq_length * actor_num_layers *
+            (actor_hidden_size**2)) * (
+                1.0 + (seq_length / (6.0 * actor_hidden_size)) +
+                (actor_vocab_size /
+                 (16.0 * actor_num_layers * actor_hidden_size)))

         gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model *
                                                 (10**12))

-        if hf_config.torch_dtype == "float16":
+        if actor_hf_config.torch_dtype == torch.float16:
             num_bytes = 2
-        elif hf_config.torch_dtype == "float32":
+        elif actor_hf_config.torch_dtype == torch.float32:
             num_bytes = 4
         else:
-            num_bytes = 1
+            num_bytes = -1

-        gen_bw = (hf_model._num_params *
-                  (num_bytes / 1e9)) / gen_exp_time * args.max_answer_seq_len
+        pertok_lat = gen_exp_time / args.max_answer_seq_len
+        gen_bw = 1 / pertok_lat * actor_model._num_params * num_bytes / 1e9

-        total_flops_per_iteration = train_flops_per_iteration + gen_flops_per_iteration * args.generation_batches
+        total_flops_per_iteration = total_train_flops + gen_flops_per_iteration * args.generation_batches
         total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model *
                                                     (10**12))

         print(
-            f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}"
+            f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Total Seq. Length: {seq_length}"
         )
         print(
-            f"Generation => Latency: {gen_exp_time:.2f}s, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw:.2f} GB/sec"
+            f"Generation => Latency: {gen_exp_time:.2f}s, Per-token Latency {pertok_lat*1000:.2f} ms, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw if num_bytes > 0 else num_bytes:.2f} GB/sec, Answer Seq. Length: {args.max_answer_seq_len}"
         )
         print(
             f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}"
         )
-        param_string = f"{params_in_billions:.3f} B" if params_in_billions != 0 else "NA"
-        print(f"Parameters => {param_string}")
+        actor_param_string = f"{actor_params_in_billions:.3f} B" if actor_params_in_billions != 0 else "NA"
+        critic_param_string = f"{critic_params_in_billions:.3f} B" if critic_params_in_billions != 0 else "NA"
+        print(
+            f"Actor Model Parameters => {actor_param_string}, Critic Model Parameters => {critic_param_string}"
+        )
+
+
+# Helper function to calculate FLOPs using the Megatron-LM paper's formula
+def calculate_flops(checkpoint_activations_factor, batch_size, seq_length,
+                    hf_config):
+    num_layers, hidden_size, vocab_size = get_hf_configs(hf_config)
+    flops_per_iteration = (24 * checkpoint_activations_factor * batch_size *
+                           seq_length * num_layers * (hidden_size**2)) * (
+                               1.0 + (seq_length / (6.0 * hidden_size)) +
+                               (vocab_size /
+                                (16.0 * num_layers * hidden_size)))
+    return flops_per_iteration
+
+
+def get_hf_configs(hf_config):
+    num_layers = getattr(hf_config, "num_hidden_layers",
+                         getattr(hf_config, "n_layer", None))
+    hidden_size = getattr(hf_config, "hidden_size",
+                          getattr(hf_config, "n_embd", None))
+    vocab_size = getattr(hf_config, "vocab_size", None)
+    assert all(
+        (num_layers, hidden_size, vocab_size)
+    ), "Could not determine number of layers, hidden size, and vocab size of the model"
+
+    return num_layers, hidden_size, vocab_size
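
For reference, a minimal self-contained sketch of how the new calculate_flops helper combines actor and critic training FLOPs into a per-GPU TFLOPs number, mirroring the updated print_throughput_step3. The config values, batch size, sequence length, step time, and GPU count below are hypothetical and chosen only for illustration; the import assumes utils/perf.py is on the Python path.

# Illustrative sketch only; not part of the patch. All numbers are hypothetical.
from types import SimpleNamespace

from perf import calculate_flops  # assumes applications/DeepSpeed-Chat/training/utils is importable

# Hypothetical HF-style configs for an actor and a smaller critic model.
actor_cfg = SimpleNamespace(num_hidden_layers=24, hidden_size=2048, vocab_size=50272)
critic_cfg = SimpleNamespace(num_hidden_layers=12, hidden_size=768, vocab_size=50272)

batch_size, seq_length = 32, 512      # samples per PPO step, prompt + answer tokens
train_time, gpus_per_model = 1.5, 8   # seconds spent training per step, world size

# Factor 3 assumes gradient checkpointing is disabled and LoRA is not used.
actor_flops = calculate_flops(3, batch_size, seq_length, actor_cfg)
critic_flops = calculate_flops(3, batch_size, seq_length, critic_cfg)

# The fix: training TFLOPs now account for both models, not just the actor.
total_train_flops = actor_flops + critic_flops
train_tflops = total_train_flops / (train_time * gpus_per_model * 10**12)
print(f"Training TFLOPs per GPU: {train_tflops:.2f}")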

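Similarly, a hedged sketch of the new per-token latency and generation bandwidth report: the estimate effectively treats each generated token as reading the full actor weights once, and num_bytes is now reported as -1 (instead of silently assuming 1 byte) when torch_dtype is neither fp16 nor fp32. All values below are illustrative.

# Illustrative sketch only; numbers are hypothetical.
actor_num_params = 1.3e9        # hypothetical actor parameter count
num_bytes = 2                   # fp16; the patch reports -1 for other/unknown dtypes
gen_exp_time = 12.0             # seconds spent in the generation phase
max_answer_seq_len = 256        # answer tokens generated per sample

pertok_lat = gen_exp_time / max_answer_seq_len                 # seconds per generated token
gen_bw = 1 / pertok_lat * actor_num_params * num_bytes / 1e9   # GB of weights read per second

print(f"Per-token Latency {pertok_lat * 1000:.2f} ms, BW: {gen_bw:.2f} GB/sec")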