@@ -34,6 +34,7 @@ class SchedulerType(Enum):
34
34
POLYNOMIAL = "polynomial"
35
35
CONSTANT = "constant"
36
36
CONSTANT_WITH_WARMUP = "constant_with_warmup"
37
+ PIECEWISE_CONSTANT = "piecewise_constant"
37
38
38
39
39
40
def get_constant_schedule (optimizer : Optimizer , last_epoch : int = - 1 ):
@@ -77,6 +78,48 @@ def lr_lambda(current_step: int):
77
78
return LambdaLR (optimizer , lr_lambda , last_epoch = last_epoch )
78
79
79
80
81
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
    """
    Create a schedule with a piecewise constant learning rate, obtained by multiplying the
    learning rate set in the optimizer by a step-dependent factor.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning-rate multiplier, e.g. `step_rules="1:10,0.1:20,0.01:30,0.005"`:
            the learning rate is multiplied by 1 while the step count is below 10, by 0.1 while it is
            below 20, by 0.01 while it is below 30, and by 0.005 from step 30 onwards. The number
            after each `:` is an absolute step threshold, not a phase duration.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Parse "multiplier:threshold" pairs; the trailing entry (which has no ":")
    # is the multiplier applied once every threshold has been passed.
    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        rules_dict[int(steps_str)] = float(value_str)
    last_lr_multiple = float(rule_list[-1])

    def create_rules_function(rules_dict, last_lr_multiple):
        # Sort the thresholds once at construction time instead of on every
        # scheduler step inside `rule_func`.
        sorted_steps = sorted(rules_dict.keys())

        def rule_func(steps: int) -> float:
            # Return the multiplier of the first threshold the step count has
            # not yet reached; past all thresholds, use the trailing multiplier.
            for sorted_step in sorted_steps:
                if steps < sorted_step:
                    return rules_dict[sorted_step]
            return last_lr_multiple

        return rule_func

    rules_func = create_rules_function(rules_dict, last_lr_multiple)

    # NOTE(review): `get_scheduler` currently calls this with the keyword
    # `rules=step_rules`, which does not match this parameter name and would
    # raise a TypeError — the call site should pass `step_rules=step_rules`.
    # TODO: confirm and fix the caller.
    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
80
123
def get_linear_schedule_with_warmup (optimizer , num_warmup_steps , num_training_steps , last_epoch = - 1 ):
81
124
"""
82
125
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
@@ -232,12 +275,14 @@ def lr_lambda(current_step: int):
232
275
SchedulerType .POLYNOMIAL : get_polynomial_decay_schedule_with_warmup ,
233
276
SchedulerType .CONSTANT : get_constant_schedule ,
234
277
SchedulerType .CONSTANT_WITH_WARMUP : get_constant_schedule_with_warmup ,
278
+ SchedulerType .PIECEWISE_CONSTANT : get_piecewise_constant_schedule ,
235
279
}
236
280
237
281
238
282
def get_scheduler (
239
283
name : Union [str , SchedulerType ],
240
284
optimizer : Optimizer ,
285
+ step_rules : Optional [str ] = None ,
241
286
num_warmup_steps : Optional [int ] = None ,
242
287
num_training_steps : Optional [int ] = None ,
243
288
num_cycles : int = 1 ,
@@ -252,6 +297,8 @@ def get_scheduler(
252
297
The name of the scheduler to use.
253
298
optimizer (`torch.optim.Optimizer`):
254
299
The optimizer that will be used during training.
300
+ step_rules (`str`, *optional*):
301
+ A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
255
302
num_warmup_steps (`int`, *optional*):
256
303
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
257
304
optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -270,6 +317,9 @@ def get_scheduler(
270
317
if name == SchedulerType .CONSTANT :
271
318
return schedule_func (optimizer , last_epoch = last_epoch )
272
319
320
+ if name == SchedulerType .PIECEWISE_CONSTANT :
321
+ return schedule_func (optimizer , rules = step_rules , last_epoch = last_epoch )
322
+
273
323
# All other schedulers require `num_warmup_steps`
274
324
if num_warmup_steps is None :
275
325
raise ValueError (f"{ name } requires `num_warmup_steps`, please provide that argument." )
0 commit comments