
Commit 6f7f121

lekurile authored and LeetJoe committed
Add LoRA LR for DS Chat steps 1-3 (deepspeedai#685)
This PR adds an explicit LoRA learning rate argument for DS Chat steps 1 through 3.

Step 1:
- lora_learning_rate

Step 2:
- lora_learning_rate

Step 3:
- actor_lora_learning_rate
- critic_lora_learning_rate
1 parent bcbc838 · commit 6f7f121

File tree

4 files changed (+36, -6 lines)
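All three steps route the new values into get_optimizer_grouped_parameters, so the change relies on that shared helper accepting a LoRA learning rate and attaching it to the LoRA parameter group. The sketch below is a hypothetical reimplementation for illustration, not the code from DeepSpeed-Chat's utils; the parameter-name patterns and the three-way grouping are assumptions.

def get_optimizer_grouped_parameters(model, weight_decay, lora_lr=5e-4,
                                     no_decay_names=("bias", "LayerNorm.weight"),
                                     lora_names=("lora_right_weight", "lora_left_weight")):
    """Sketch: split trainable params so LoRA adapters get their own lr."""
    decay, no_decay, lora = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(tag in name for tag in lora_names):
            lora.append(param)        # LoRA adapter weights
        elif any(tag in name for tag in no_decay_names):
            no_decay.append(param)    # biases / norms: no weight decay
        else:
            decay.append(param)       # regular weights
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
        # The "lr" key is what gives the LoRA parameters their own learning rate.
        {"params": lora, "weight_decay": weight_decay, "lr": lora_lr},
    ]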

applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py

Lines changed: 8 additions & 1 deletion

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -274,7 +281,7 @@ def evaluation(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        model, args.weight_decay)
+        model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,
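A note on why passing the extra argument is enough: PyTorch optimizers let each parameter group carry its own "lr", which overrides the optimizer-level default, so only the group built from the LoRA weights needs to carry args.lora_learning_rate. A minimal, self-contained illustration, using torch.optim.AdamW as a stand-in for FusedAdam/DeepSpeedCPUAdam and made-up tensors:

import torch

base = torch.nn.Linear(16, 16)                           # stand-in for the base model
lora_adapter = torch.nn.Parameter(torch.zeros(16, 16))   # stand-in for a LoRA weight

groups = [
    {"params": base.parameters(), "weight_decay": 0.1},           # falls back to the default lr
    {"params": [lora_adapter], "weight_decay": 0.0, "lr": 5e-4},  # LoRA-specific lr
]
optimizer = torch.optim.AdamW(groups, lr=1e-5)
print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0005]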

applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py

Lines changed: 8 additions & 1 deletion

@@ -161,6 +161,13 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Tensorboard logging
     parser.add_argument('--enable_tensorboard',
                         action='store_true',
@@ -271,7 +278,7 @@ def evaluation_reward(model, eval_dataloader):

     # Split weights in two groups, one with weight decay and the other not.
     optimizer_grouped_parameters = get_optimizer_grouped_parameters(
-        rm_model, args.weight_decay)
+        rm_model, args.weight_decay, args.lora_learning_rate)

     AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
     optimizer = AdamOptimizer(optimizer_grouped_parameters,

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 14 additions & 0 deletions

@@ -286,6 +286,20 @@ def parse_args():
     parser.add_argument('--only_optimize_lora',
                         action='store_true',
                         help='Only optimize the LoRA parameters.')
+    parser.add_argument(
+        "--actor_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial actor LoRA learning rate (after the potential warmup period) to use."
+    )
+    parser.add_argument(
+        "--critic_lora_learning_rate",
+        type=float,
+        default=5e-4,
+        help=
+        "Initial critic LoRA learning rate (after the potential warmup period) to use."
+    )
     ## Make EMA as an optional feature
     parser.add_argument('--enable_ema',
                         action='store_true',

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py

Lines changed: 6 additions & 4 deletions

@@ -105,7 +105,8 @@ def _init_actor(self, actor_model_name_or_path):
         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
         optim_params = get_optimizer_grouped_parameters(
-            actor_model, self.args.actor_weight_decay)
+            actor_model, self.args.actor_weight_decay,
+            self.args.actor_lora_learning_rate)
         optim = AdamOptimizer(optim_params,
                               lr=self.args.actor_learning_rate,
                               betas=(0.9, 0.95))
@@ -231,9 +232,10 @@ def _init_critic(self, critic_model_name_or_path):

         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam
-        optim_pararms = get_optimizer_grouped_parameters(
-            critic_model, self.args.critic_weight_decay)
-        optim = AdamOptimizer(optim_pararms,
+        optim_params = get_optimizer_grouped_parameters(
+            critic_model, self.args.critic_weight_decay,
+            self.args.critic_lora_learning_rate)
+        optim = AdamOptimizer(optim_params,
                               lr=self.args.critic_learning_rate,
                               betas=(0.9, 0.95))
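With both engines wired up, the actor and critic optimizers each end up with a LoRA group whose base learning rate is the corresponding new argument. Assuming the learning-rate schedule is LambdaLR-based (as the usual linear/cosine warmup schedules are), each parameter group keeps its own base lr and is scaled by the same factor, so the 5e-4 LoRA default warms up and decays alongside the much smaller actor/critic learning rates. A small sketch with placeholder parameters:

import torch

w_base = torch.nn.Parameter(torch.zeros(4))   # placeholder "base model" weight
w_lora = torch.nn.Parameter(torch.zeros(4))   # placeholder LoRA weight

optimizer = torch.optim.AdamW([
    {"params": [w_base]},              # inherits the default lr below
    {"params": [w_lora], "lr": 5e-4},  # LoRA group keeps its own base lr
], lr=1e-5)

warmup_steps = 10
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

for _ in range(3):
    optimizer.step()   # no gradients here; parameters are simply skipped
    scheduler.step()

# Both groups are scaled by the same warmup factor, each around its own base lr.
print([g["lr"] for g in optimizer.param_groups])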
