Skip to content

Commit de45af4

Browse files
authored
Allow setting num_cycles for cosine_with_restarts lr scheduler (#3606)
Expose num_cycles kwarg of get_schedule() through args.lr_num_cycles.
1 parent b95cbdf commit de45af4

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,12 @@ def parse_args():
285285
parser.add_argument(
286286
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
287287
)
288+
parser.add_argument(
289+
"--lr_num_cycles",
290+
type=int,
291+
default=1,
292+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
293+
)
288294
parser.add_argument(
289295
"--dataloader_num_workers",
290296
type=int,
@@ -739,6 +745,7 @@ def main():
739745
optimizer=optimizer,
740746
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
741747
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
748+
num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps,
742749
)
743750

744751
# Prepare everything with our `accelerator`.

0 commit comments

Comments
 (0)