
Commit a2b4c65

Warmup schedulers in References (#4411)
* Warmup on Classification references.
* Adjust epochs for cosine.
* Warmup on Segmentation references.
* Warmup on Video classification references.
* Adding support of both types of warmup in segmentation.
* Use LinearLR in detection.
* Fix deprecation warning.
1 parent 9275cc6 commit a2b4c65

File tree

7 files changed, +78 -71 lines changed

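All of these references now share the same composition pattern: a short warmup scheduler (LinearLR or ConstantLR) is chained in front of the main scheduler with torch.optim.lr_scheduler.SequentialLR. The sketch below is illustrative only; the model, learning rate and epoch counts are made up and not taken from the reference scripts.

# Minimal sketch of the shared warmup pattern (illustrative values, not the reference defaults).
import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

warmup_epochs, total_epochs = 5, 30

# Warmup phase: ramp the LR linearly from lr * start_factor up to lr.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_epochs)
# Main phase: its length is shortened by the warmup so the whole run still covers total_epochs.
main = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs - warmup_epochs)
# SequentialLR switches from the warmup scheduler to the main one at the milestone.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, main], milestones=[warmup_epochs])

for epoch in range(total_epochs):
    # train_one_epoch(...) would run here
    scheduler.step()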

references/classification/train.py

Lines changed: 22 additions & 1 deletion
@@ -208,7 +208,25 @@ def main(args):
                                           opt_level=args.apex_opt_level
                                           )
 
-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    args.lr_scheduler = args.lr_scheduler.lower()
+    if args.lr_scheduler == 'steplr':
+        main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+    elif args.lr_scheduler == 'cosineannealinglr':
+        main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                                       T_max=args.epochs - args.lr_warmup_epochs)
+    else:
+        raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR and CosineAnnealingLR "
+                           "are supported.".format(args.lr_scheduler))
+
+    if args.lr_warmup_epochs > 0:
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                            total_iters=args.lr_warmup_epochs), main_lr_scheduler],
+            milestones=[args.lr_warmup_epochs]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -287,6 +305,9 @@ def get_args_parser(add_help=True):
                         dest='label_smoothing')
     parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
     parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
+    parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
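For the classification script the warmup runs at epoch granularity: ConstantLR holds the learning rate at lr * lr_warmup_decay for the first lr_warmup_epochs epochs, then SequentialLR hands over to StepLR or CosineAnnealingLR starting from the full rate. A small illustrative check of that behaviour (numbers are made up; a recent PyTorch is assumed):

# ConstantLR warmup followed by StepLR, mirroring the classification logic above.
import torch

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
main = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.1)
sched = torch.optim.lr_scheduler.SequentialLR(
    opt,
    schedulers=[torch.optim.lr_scheduler.ConstantLR(opt, factor=0.01, total_iters=2), main],
    milestones=[2],
)
for epoch in range(6):
    # expected with recent PyTorch: roughly 0.001, 0.001, 0.1, 0.1, 0.1, 0.01
    print(epoch, opt.param_groups[0]['lr'])
    sched.step()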

references/detection/engine.py

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,8 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
         warmup_factor = 1. / 1000
         warmup_iters = min(1000, len(data_loader) - 1)
 
-        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
+        lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor,
+                                                         total_iters=warmup_iters)
 
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
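The detection reference only warms up during the first epoch, stepping the scheduler once per iteration. LinearLR ramps the rate from warmup_factor * lr up to lr over total_iters steps and then leaves it alone; a tiny illustrative loop (the lr and iteration count below are made up):

# Per-iteration warmup with LinearLR (illustrative values; the reference uses
# warmup_factor = 1/1000 and warmup_iters = min(1000, len(data_loader) - 1)).
import torch

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.02)
warmup_factor, warmup_iters = 1.0 / 1000, 5
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=warmup_factor, total_iters=warmup_iters)

for it in range(warmup_iters + 2):
    print(it, optimizer.param_groups[0]['lr'])  # climbs from 0.02/1000 towards 0.02, then stays there
    # loss.backward(); optimizer.step() would run here in the real loop
    lr_scheduler.step()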

references/detection/utils.py

Lines changed: 0 additions & 11 deletions
@@ -207,17 +207,6 @@ def collate_fn(batch):
     return tuple(zip(*batch))
 
 
-def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
-
-    def f(x):
-        if x >= warmup_iters:
-            return 1
-        alpha = float(x) / warmup_iters
-        return warmup_factor * (1 - alpha) + alpha
-
-    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
-
-
 def mkdir(path):
     try:
         os.makedirs(path)
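The deleted helper implemented the same linear ramp by hand through LambdaLR. A quick self-contained check (illustrative, not from the repo) that its multiplier matches LinearLR's closed form with end_factor left at its default of 1:

# The removed LambdaLR-based warmup and LinearLR give the same multiplier schedule.
warmup_factor, warmup_iters = 1.0 / 1000, 4

def old_multiplier(x):
    # formula from the removed warmup_lr_scheduler helper
    if x >= warmup_iters:
        return 1.0
    alpha = float(x) / warmup_iters
    return warmup_factor * (1 - alpha) + alpha

def linearlr_multiplier(x):
    # closed form of torch.optim.lr_scheduler.LinearLR with end_factor=1 (its default)
    return warmup_factor + (1 - warmup_factor) * min(x, warmup_iters) / warmup_iters

for x in range(warmup_iters + 3):
    assert abs(old_multiplier(x) - linearlr_multiplier(x)) < 1e-12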

references/segmentation/train.py

Lines changed: 26 additions & 2 deletions
@@ -133,9 +133,30 @@ def main(args):
         params_to_optimize,
         lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
-    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+    iters_per_epoch = len(data_loader)
+    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
-        lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
+        lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9)
+
+    if args.lr_warmup_epochs > 0:
+        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=warmup_iters)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=warmup_iters)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
@@ -197,6 +218,9 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
+    parser.add_argument('--lr-warmup-decay', default=0.01, type=int, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
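The segmentation schedule steps once per iteration rather than per epoch, so the warmup spans iters_per_epoch * lr_warmup_epochs scheduler steps and the polynomial decay is stretched over only the remaining epochs. A hedged, self-contained sketch of that handover (sizes and hyper-parameters below are made up; the real script derives iters_per_epoch from the data loader):

# Per-iteration warmup followed by polynomial decay, mirroring the logic above.
import torch

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
iters_per_epoch, epochs, lr_warmup_epochs, lr_warmup_decay = 10, 4, 1, 0.01

main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda x: (1 - x / (iters_per_epoch * (epochs - lr_warmup_epochs))) ** 0.9)
warmup_iters = iters_per_epoch * lr_warmup_epochs
warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=lr_warmup_decay, total_iters=warmup_iters)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters])

for step in range(warmup_iters + 3):
    print(step, optimizer.param_groups[0]['lr'])  # linear ramp for 10 steps, then polynomial decay from 0.01
    lr_scheduler.step()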

references/segmentation/transforms.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def __init__(self, min_size, max_size=None):
     def __call__(self, image, target):
         size = random.randint(self.min_size, self.max_size)
         image = F.resize(image, size)
-        target = F.resize(target, size, interpolation=Image.NEAREST)
+        target = F.resize(target, size, interpolation=T.InterpolationMode.NEAREST)
         return image, target
 
 
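The deprecation fix swaps the PIL constant for torchvision's InterpolationMode enum (T in this file presumably aliases torchvision.transforms). An illustrative standalone use of the new spelling, with a made-up mask size:

# Nearest-neighbour resize of a segmentation mask with the non-deprecated enum;
# nearest interpolation keeps the integer class ids intact.
from PIL import Image
from torchvision import transforms as T
from torchvision.transforms import functional as F

mask = Image.new('L', (64, 48))  # dummy single-channel mask
resized = F.resize(mask, 32, interpolation=T.InterpolationMode.NEAREST)
print(resized.size)  # smaller edge scaled to 32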

references/video_classification/scheduler.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

references/video_classification/train.py

Lines changed: 27 additions & 8 deletions
@@ -12,8 +12,6 @@
 import presets
 import utils
 
-from scheduler import WarmupMultiStepLR
-
 try:
     from apex import amp
 except ImportError:
@@ -202,11 +200,30 @@ def main(args):
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
-    warmup_iters = args.lr_warmup_epochs * len(data_loader)
-    lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
-    lr_scheduler = WarmupMultiStepLR(
-        optimizer, milestones=lr_milestones, gamma=args.lr_gamma,
-        warmup_iters=warmup_iters, warmup_factor=1e-5)
+    iters_per_epoch = len(data_loader)
+    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
+    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)
+
+    if args.lr_warmup_epochs > 0:
+        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
+        args.lr_warmup_method = args.lr_warmup_method.lower()
+        if args.lr_warmup_method == 'linear':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay,
+                                                                    total_iters=warmup_iters)
+        elif args.lr_warmup_method == 'constant':
+            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay,
+                                                                      total_iters=warmup_iters)
+        else:
+            raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant "
+                               "are supported.".format(args.lr_warmup_method))
+
+        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers=[warmup_lr_scheduler, main_lr_scheduler],
+            milestones=[warmup_iters]
+        )
+    else:
+        lr_scheduler = main_lr_scheduler
 
     model_without_ddp = model
     if args.distributed:
@@ -277,7 +294,9 @@ def parse_args():
                         dest='weight_decay')
     parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
-    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='number of warmup epochs')
+    parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)')
+    parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)')
+    parser.add_argument('--lr-warmup-decay', default=0.001, type=int, help='the decay for lr')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
     parser.add_argument('--output-dir', default='.', help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
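Because the schedule steps per iteration and SequentialLR only lets the MultiStepLR start counting once warmup has finished, the milestones are expressed relative to the end of warmup, hence the (m - args.lr_warmup_epochs) shift. A small illustrative check of that bookkeeping (iters_per_epoch is made up; the script uses len(data_loader)):

# Milestones handed to MultiStepLR are counted from the end of warmup, so shifting
# them by lr_warmup_epochs keeps the LR drops at the same absolute epochs as before.
lr_warmup_epochs = 10
lr_milestones = [20, 30, 40]      # absolute epochs at which the LR should drop (reference defaults)
iters_per_epoch = 100             # illustrative

warmup_iters = iters_per_epoch * lr_warmup_epochs
shifted = [iters_per_epoch * (m - lr_warmup_epochs) for m in lr_milestones]
print(shifted)  # [1000, 2000, 3000] iterations, counted from the end of warmup
assert all(warmup_iters + s == iters_per_epoch * m for s, m in zip(shifted, lr_milestones))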
