 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 from torch.optim.lr_scheduler import LambdaLR
 from torchtitan.config_manager import JobConfig
 
 
-# global states for scheduling
-# these are needed as LambdaLR does not support argument passing
-_warmup_steps = 200
-_decay_steps = 0
-
 
-def linear_warmup_linear_decay(current_step: int) -> float:
+def linear_warmup_linear_decay(
+    warmup_steps: int, decay_steps: int, current_step: int
+) -> float:
     """Computes linear warmup followed by linear decay.
     Per LambdaLR requirement, this is accomplished by returning
     a multiplicative factor to adjust the learning rate to
     create the desired schedule.
     """
-    if current_step < _warmup_steps:
+    if current_step < warmup_steps:
         # linear warmup
         # 0-indexed step, hence + 1 adjustments
         current_step += 1
-        curr_adjustment = float(current_step / (_warmup_steps + 1))
+        curr_adjustment = float(current_step / (warmup_steps + 1))
 
     else:
         # linear decay
-        normalized_step = _decay_steps - (current_step - _warmup_steps)
-        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+        normalized_step = decay_steps - (current_step - warmup_steps)
+        curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps
 
     return curr_adjustment
 
 
 def get_lr_schedulers(optimizers, job_config: JobConfig):
     def _get_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
-        global _warmup_steps, _decay_steps
-        _warmup_steps = int(job_config.training.warmup_steps)
-        _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
-
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+        warmup_steps = int(job_config.training.warmup_steps)
+        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+        lr_lambda = functools.partial(
+            linear_warmup_linear_decay, warmup_steps, decay_steps
+        )
+        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
         return warmup_scheduler
 
     class SchedulersContainer:
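Below is a minimal, self-contained sketch (not part of the commit) showing why the `functools.partial` approach works: `LambdaLR` calls its `lr_lambda` with only the step index, so `warmup_steps` and `decay_steps` are bound ahead of time instead of being read from module-level globals. The toy model, optimizer, and step counts here are illustrative only, not torchtitan defaults.

```python
# Illustration only: toy model/optimizer and made-up step counts.
import functools

import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup_linear_decay(
    warmup_steps: int, decay_steps: int, current_step: int
) -> float:
    # Same schedule as the function in the diff: linear warmup, then linear decay.
    if current_step < warmup_steps:
        return float((current_step + 1) / (warmup_steps + 1))
    normalized_step = decay_steps - (current_step - warmup_steps)
    return 1 - (decay_steps - normalized_step) / decay_steps


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_steps, decay_steps = 5, 20.0
# partial binds the first two positional args; LambdaLR then supplies current_step.
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in range(10):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())  # lr warms up, then decays
```

Binding the arguments positionally matches the new signature's argument order (`warmup_steps`, `decay_steps`, then `current_step`), which is what lets the partial satisfy `LambdaLR`'s single-argument contract.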