
Commit a7b0671

add constant learning rate with custom rule (#3133)
* add constant lr with rules
* add constant with rules in TYPE_TO_SCHEDULER_FUNCTION
* add constant lr rate with rule
* hotfix code quality
* fix doc style
* change name constant_with_rules to piecewise constant
1 parent be0bfce commit a7b0671

File tree

1 file changed, +50 -0 lines changed


src/diffusers/optimization.py

Lines changed: 50 additions & 0 deletions
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    PIECEWISE_CONSTANT = "piecewise_constant"


 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,48 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a piecewise constant learning rate, scaling the learning rate set in the optimizer
+    according to `step_rules`.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        step_rules (`string`):
+            The rules for the learning rate. ex: step_rules="1:10,0.1:20,0.01:30,0.005" means the learning rate
+            is multiplied by 1 for the first 10 steps, by 0.1 until step 20, by 0.01 until step 30, and by 0.005
+            for the remaining steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = step_rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr_multiple = float(rule_list[-1])
+
+    def create_rules_function(rules_dict, last_lr_multiple):
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr_multiple
+
+        return rule_func
+
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)
+
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
+
+
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
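The comma-separated `step_rules` string above fully determines the multiplier applied to the optimizer's base learning rate at each step. A minimal usage sketch, assuming a PyTorch environment with a diffusers build that includes this commit; the toy model, optimizer, and loop are illustrative only:

# Sketch: calling the new scheduler directly, assuming diffusers at this commit is importable.
import torch
from diffusers.optimization import get_piecewise_constant_schedule

model = torch.nn.Linear(2, 2)  # hypothetical toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# "1:10,0.1:20,0.01:30,0.005": multiply the base lr by 1 before step 10,
# by 0.1 before step 20, by 0.01 before step 30, and by 0.005 afterwards.
scheduler = get_piecewise_constant_schedule(optimizer, step_rules="1:10,0.1:20,0.01:30,0.005")

for step in range(35):
    optimizer.step()
    scheduler.step()
    if step in (0, 10, 20, 30):
        # expect roughly 1e-3, 1e-4, 1e-5, 5e-6 as the step count crosses each boundary
        print(step, scheduler.get_last_lr())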
@@ -232,12 +275,14 @@ def lr_lambda(current_step: int):
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
 }


 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
+    step_rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
@@ -252,6 +297,8 @@
             The name of the scheduler to use.
         optimizer (`torch.optim.Optimizer`):
             The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
         num_warmup_steps (`int`, *optional*):
             The number of warmup steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -270,6 +317,9 @@
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)

+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
