Skip to content

Commit fb3136a

Browse files
Ishrat Badamimeetps
Ishrat Badami
authored and committed
added learning rate scheduler in training
1 parent 4f9fa49 commit fb3136a

File tree

2 files changed

+25
-18
lines changed

2 files changed

+25
-18
lines changed

lr_scheduling.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9):
    """Polynomial decay of learning rate.

    Updates every ``param_group['lr']`` of *optimizer* in place to
    ``init_lr * (1 - iter / max_iter) ** power``.

    :param optimizer: optimizer whose ``param_groups`` learning rates are set
    :param init_lr: base learning rate
    :param iter: current iteration (name kept for backward compatibility,
        even though it shadows the builtin ``iter``)
    :param lr_decay_iter: how frequently decay occurs, default is 1
    :param max_iter: number of maximum iterations
    :param power: polynomial power
    :return: the optimizer (also mutated in place)
    """
    # Only decay on multiples of lr_decay_iter, and never past max_iter.
    if iter % lr_decay_iter or iter > max_iter:
        return optimizer

    # Hoist the loop-invariant computation out of the loop.
    new_lr = init_lr * (1 - iter / max_iter) ** power
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

    # Return the optimizer so both exit paths are consistent (the original
    # fell through here and implicitly returned None, unlike the early exit).
    return optimizer
def adjust_learning_rate(optimizer, init_lr, epoch):
    """Set every param group's learning rate to the initial LR decayed
    by a factor of 10 for each completed block of 30 epochs."""
    decay_steps = epoch // 30
    new_lr = init_lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr

train.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,7 @@
1414
from ptsemseg.loader import get_loader, get_data_path
1515
from ptsemseg.loss import cross_entropy2d
1616
from ptsemseg.metrics import scores
17-
18-
19-
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1, max_iter=30000, power=0.9,):
20-
"""Polynomial decay of learning rate
21-
:param init_lr is base learning rate
22-
:param iter is a current iteration
23-
:param lr_decay_iter how frequently decay occurs, default is 1
24-
:param max_iter is number of maximum iterations
25-
:param power is a polymomial power
26-
27-
"""
28-
if iter % lr_decay_iter or iter > max_iter:
29-
return optimizer
30-
31-
for param_group in optimizer.param_groups:
32-
param_group['lr'] = init_lr*(1 - iter/max_iter)**power
33-
return optimizer
34-
17+
from lr_scheduling import *
3518

3619
def train(args):
3720

@@ -74,6 +57,9 @@ def train(args):
7457
images = Variable(images)
7558
labels = Variable(labels)
7659

60+
iter = len(trainloader)*epoch + i
61+
poly_lr_scheduler(optimizer, args.l_rate, iter)
62+
7763
optimizer.zero_grad()
7864
outputs = model(images)
7965

0 commit comments

Comments
 (0)