4 files changed under applications/DeepSpeed-Chat/training, +36 -6 lines; the file tree lists step1_supervised_finetuning and step2_reward_model_finetuning among the changed paths.

First changed file (step1_supervised_finetuning): add a --lora_learning_rate argument and pass it to get_optimizer_grouped_parameters.

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -274,7 +281,7 @@ def evaluation(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        model, args.weight_decay)
+        model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,
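Each call site now passes a third positional argument, which implies the shared get_optimizer_grouped_parameters helper accepts a LoRA-specific learning rate. The following is only a sketch of what such a helper could look like, under the assumption that LoRA tensors are recognized by a name substring and placed in a parameter group carrying its own lr key; the group layout and keyword names are illustrative, not the repository's actual implementation.

    # Hypothetical sketch: split parameters into (1) weight-decayed, (2) non-decayed,
    # and (3) LoRA groups, where the LoRA group carries its own learning rate.
    # The function name matches the diff; the body and the "lora" name check are assumed.
    import torch.nn as nn


    def get_optimizer_grouped_parameters(model: nn.Module,
                                         weight_decay: float,
                                         lora_lr: float = 5e-4,
                                         no_decay_keywords=("bias", "LayerNorm.weight"),
                                         lora_keyword="lora"):
        decay, no_decay, lora = [], [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if lora_keyword in name:
                lora.append(param)              # LoRA adapters: separate lr
            elif any(k in name for k in no_decay_keywords):
                no_decay.append(param)          # biases / norms: no weight decay
            else:
                decay.append(param)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": no_decay, "weight_decay": 0.0},
            {"params": lora, "weight_decay": weight_decay, "lr": lora_lr},
        ]

Because only the last group defines an explicit lr, the optimizer applies lora_lr to those tensors while every other parameter falls back to the learning rate passed to the optimizer constructor.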
Second changed file (step2_reward_model_finetuning): the same flag, wired into the reward-model optimizer.

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -271,7 +278,7 @@ def evaluation_reward(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        rm_model, args.weight_decay)
+        rm_model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,
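Both call sites rely on standard PyTorch parameter-group semantics: a group that defines its own lr overrides the lr given to the optimizer constructor, which is how the non-LoRA parameters keep --learning_rate while the LoRA group trains at --lora_learning_rate. A small self-contained check of that behavior, using torch.optim.AdamW only as a stand-in for DeepSpeedCPUAdam/FusedAdam (which expose the same parameter-group interface):

    # Stand-alone check of per-group learning rates; AdamW stands in for the
    # DeepSpeed optimizers purely for illustration.
    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    groups = [
        {"params": [model.weight], "weight_decay": 0.1},           # uses optimizer-level lr
        {"params": [model.bias], "weight_decay": 0.0, "lr": 5e-4}, # pretend LoRA group
    ]
    optimizer = torch.optim.AdamW(groups, lr=1e-5)
    print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0005]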
Third changed file: separate --actor_lora_learning_rate and --critic_lora_learning_rate arguments are added to the argument parser.

@@ -286,6 +286,20 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--actor_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial actor LoRA learning rate (after the potential warmup period) to use."
+    )
+    parser.add_argument(
+        "--critic_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial critic LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Make EMA as an optional feature
     parser.add_argument('--enable_ema',
                         action='store_true',
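With these flags the actor and critic LoRA learning rates can be tuned independently, and both default to 5e-4 when omitted. A minimal argparse reproduction of just these two options (the rest of the parser is left out):

    # Minimal reproduction of the two new flags and their defaults.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--actor_lora_learning_rate", type=float, default=5e-4)
    parser.add_argument("--critic_lora_learning_rate", type=float, default=5e-4)

    args = parser.parse_args([])  # no flags given: both fall back to 5e-4
    print(args.actor_lora_learning_rate, args.critic_lora_learning_rate)   # 0.0005 0.0005

    args = parser.parse_args(["--critic_lora_learning_rate", "1e-4"])
    print(args.actor_lora_learning_rate, args.critic_lora_learning_rate)   # 0.0005 0.0001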
Fourth changed file: pass the per-model LoRA learning rates when building the actor and critic optimizers, fixing the optim_pararms typo along the way.

@@ -105,7 +105,8 @@ def _init_actor(self, actor_model_name_or_path):
         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
         optim_params = get_optimizer_grouped_parameters(
-            actor_model, self.args.actor_weight_decay)
+            actor_model, self.args.actor_weight_decay,
+            self.args.actor_lora_learning_rate)
         optim = AdamOptimizer(optim_params,
                               lr=self.args.actor_learning_rate,
                               betas=(0.9, 0.95))
@@ -231,9 +232,10 @@ def _init_critic(self, critic_model_name_or_path):

         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
-        optim_pararms = get_optimizer_grouped_parameters(
-            critic_model, self.args.critic_weight_decay)
-        optim = AdamOptimizer(optim_pararms,
+        optim_params = get_optimizer_grouped_parameters(
+            critic_model, self.args.critic_weight_decay,
+            self.args.critic_lora_learning_rate)
+        optim = AdamOptimizer(optim_params,
                               lr=self.args.critic_learning_rate,
                               betas=(0.9, 0.95))

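Here the same grouping helper is applied once per model, so the actor and critic each combine their own base learning rate and weight decay with their own LoRA learning rate. A compact sketch of that wiring pattern, with torch.optim.AdamW and toy modules standing in for the DeepSpeed optimizers and the real actor/critic models, and with illustrative learning-rate values:

    # Sketch of per-model optimizer construction in the spirit of _init_actor /
    # _init_critic; modules, optimizer class, and numbers are placeholders.
    import torch
    import torch.nn as nn


    def build_optimizer(model, base_lr, weight_decay, lora_lr):
        # Ordinary weights train at base_lr; anything named "lora" gets lora_lr.
        regular = [p for n, p in model.named_parameters() if "lora" not in n]
        lora = [p for n, p in model.named_parameters() if "lora" in n]
        groups = [
            {"params": regular, "weight_decay": weight_decay},
            {"params": lora, "weight_decay": weight_decay, "lr": lora_lr},
        ]
        return torch.optim.AdamW(groups, lr=base_lr, betas=(0.9, 0.95))


    actor = nn.ModuleDict({"backbone": nn.Linear(16, 16), "lora_adapter": nn.Linear(16, 4)})
    critic = nn.ModuleDict({"backbone": nn.Linear(16, 1), "lora_adapter": nn.Linear(16, 4)})

    actor_optim = build_optimizer(actor, base_lr=1e-5, weight_decay=0.0, lora_lr=5e-4)
    critic_optim = build_optimizer(critic, base_lr=5e-6, weight_decay=0.0, lora_lr=5e-4)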