diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 0169a6ab43e..1945ad1f4a2 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -10,6 +10,7 @@
 import utils
 from coco_utils import get_coco
 from torch import nn
+from torch.optim.lr_scheduler import PolynomialLR
 from torchvision.transforms import functional as F, InterpolationMode
 
 
@@ -184,8 +185,8 @@ def main(args):
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     iters_per_epoch = len(data_loader)
-    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
-        optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
+    main_lr_scheduler = PolynomialLR(
+        optimizer, total_iters=iters_per_epoch * (args.epochs - args.lr_warmup_epochs), power=0.9
     )
 
     if args.lr_warmup_epochs > 0:
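
For context, the hand-rolled LambdaLR factor `(1 - x / total) ** 0.9` and `PolynomialLR(total_iters=total, power=0.9)` compute the same per-step learning rate, so the swap is a drop-in refactor. Below is a minimal sketch that checks the two schedules agree; the dummy parameter, base LR of 0.01, and 100-step horizon are illustrative assumptions standing in for `iters_per_epoch * (args.epochs - args.lr_warmup_epochs)`.

# Minimal equivalence check (not part of the patch): the old LambdaLR schedule
# and the new PolynomialLR schedule should report the same LR at every step.
import math

import torch
from torch.optim.lr_scheduler import LambdaLR, PolynomialLR

params = [torch.nn.Parameter(torch.zeros(1))]
total_iters = 100  # stands in for iters_per_epoch * (args.epochs - args.lr_warmup_epochs)

opt_old = torch.optim.SGD(params, lr=0.01)
opt_new = torch.optim.SGD(params, lr=0.01)
sched_old = LambdaLR(opt_old, lambda x: (1 - x / total_iters) ** 0.9)
sched_new = PolynomialLR(opt_new, total_iters=total_iters, power=0.9)

for _ in range(total_iters):
    # Both schedulers should agree up to floating-point noise.
    assert math.isclose(
        sched_old.get_last_lr()[0], sched_new.get_last_lr()[0], rel_tol=1e-6, abs_tol=1e-12
    )
    sched_old.step()
    sched_new.step()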