diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index b66d117e90be..0bf76c166835 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -285,6 +285,12 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -739,6 +745,7 @@ def main(): optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps, ) # Prepare everything with our `accelerator`.