Commit 668f6cd

Author: Andrew Gu

Used partial instead of global vars for LR scheduling

ghstack-source-id: 12c4418
Pull Request resolved: #487

1 parent: 43584e0
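
To see the pattern in isolation: the sketch below is not part of the commit (the model, optimizer, and the warmup/decay values are invented for illustration), but it shows how `functools.partial` pre-binds the schedule hyperparameters so that `LambdaLR`'s single-argument `lr_lambda(step)` interface is satisfied per scheduler, rather than through module-level globals that every scheduler shares and that construction mutates.

```python
# Illustrative sketch only; names and values here are hypothetical, not from
# any torchtitan config. functools.partial turns the three-argument schedule
# function into the one-argument callable that LambdaLR expects.
import functools

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup_linear_decay(
    warmup_steps: int, decay_steps: int, current_step: int
) -> float:
    # Same formula as the diff below: linear ramp up, then linear decay.
    if current_step < warmup_steps:
        return float((current_step + 1) / (warmup_steps + 1))
    normalized_step = decay_steps - (current_step - warmup_steps)
    return 1 - (decay_steps - normalized_step) / decay_steps


model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=0.1)

# Bind warmup_steps=2, decay_steps=8; LambdaLR then calls lr_lambda(step)
# with only the step count, so no global state is needed.
lr_lambda = functools.partial(linear_warmup_linear_decay, 2, 8)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
```

With partial, each scheduler carries its own bound hyperparameters instead of all reading whichever values were last written to the globals.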

File tree

1 file changed: +15 -15 lines changed

torchtitan/lr_scheduling.py

Lines changed: 15 additions & 15 deletions
@@ -4,43 +4,43 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 from torch.optim.lr_scheduler import LambdaLR
 from torchtitan.config_manager import JobConfig
 
-# global states for scheduling
-# these are needed as LambdaLR does not support argument passing
-_warmup_steps = 200
-_decay_steps = 0
-
 
-def linear_warmup_linear_decay(current_step: int) -> float:
+def linear_warmup_linear_decay(
+    warmup_steps: int, decay_steps: int, current_step: int
+) -> float:
     """Computes linear warmup followed by linear decay.
     Per LambdaLR requirement, this is accomplished by returning
     a multiplicative factor to adjust the learning rate to
     create the desired schedule.
     """
-    if current_step < _warmup_steps:
+    if current_step < warmup_steps:
         # linear warmup
         # 0-indexed step, hence + 1 adjustments
         current_step += 1
-        curr_adjustment = float(current_step / (_warmup_steps + 1))
+        curr_adjustment = float(current_step / (warmup_steps + 1))
 
     else:
         # linear decay
-        normalized_step = _decay_steps - (current_step - _warmup_steps)
-        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+        normalized_step = decay_steps - (current_step - warmup_steps)
+        curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps
 
     return curr_adjustment
 
 
 def get_lr_schedulers(optimizers, job_config: JobConfig):
     def _get_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
-        global _warmup_steps, _decay_steps
-        _warmup_steps = int(job_config.training.warmup_steps)
-        _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
-
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+        warmup_steps = int(job_config.training.warmup_steps)
+        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+        lr_lambda = functools.partial(
+            linear_warmup_linear_decay, warmup_steps, decay_steps
+        )
+        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
         return warmup_scheduler
 
     class SchedulersContainer:
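
To make the schedule concrete, here is a small worked example of the multiplier the new lambda returns. The numbers are hypothetical (warmup_steps=200 mirrors the old module-level default, and training.steps=1000 is invented, giving decay_steps=800); they are not taken from any torchtitan config.

```python
# Hypothetical values for illustration: warmup_steps=200, decay_steps=800.
warmup_steps, decay_steps = 200, 800.0


def factor(step: int) -> float:
    # Same formula as linear_warmup_linear_decay above, inlined for clarity.
    if step < warmup_steps:
        return float((step + 1) / (warmup_steps + 1))
    normalized_step = decay_steps - (step - warmup_steps)
    return 1 - (decay_steps - normalized_step) / decay_steps


print(factor(0))    # ~0.00498 -> LR starts near zero
print(factor(199))  # ~0.99502 -> warmup almost complete
print(factor(200))  # 1.0      -> peak LR at the first post-warmup step
print(factor(999))  # 0.00125  -> last step keeps a small nonzero multiplier
```

The multiplier ramps from about 1/(warmup_steps + 1) to 1.0 over the warmup, then decays linearly so the final step still receives 1/decay_steps of the base learning rate rather than exactly zero.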
