
Commit 6c7ae16

Authored by tchaton and committed; co-authored by Borda (Jirka Borovec) and carmocca (Carlos Mocholí)
[bugfix] Resolve LearningRateMonitor + BackboneFinetuning (#7835)
* add test + resolve bug
* update changelog
* resolve bug
* resolve bug
* Update pytorch_lightning/callbacks/lr_monitor.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/callbacks/lr_monitor.py
  Co-authored-by: Jirka Borovec <[email protected]>
* update on comments
* resolve comments
* update
* Update tests/callbacks/test_lr_monitor.py
  Co-authored-by: Carlos Mocholí <[email protected]>
* Update pytorch_lightning/callbacks/lr_monitor.py
  Co-authored-by: Carlos Mocholí <[email protected]>

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

(cherry picked from commit d1becce)
1 parent 933aebc commit 6c7ae16
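For context, the failure mode this commit addresses arises when `BackboneFinetuning` adds a new parameter group to an optimizer mid-training while `LearningRateMonitor` is logging: the monitor's keys were not updated for the new group, which the `_find_names` / `_remap_keys` changes in the diff below fix. A minimal sketch of the two callbacks used together, assuming the 1.3-era API; the module, its layer sizes, and `unfreeze_backbone_at_epoch=2` are illustrative choices, not taken from the commit:

```python
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import BackboneFinetuning, LearningRateMonitor


class LitModel(LightningModule):
    # Hypothetical module; BackboneFinetuning expects a `backbone` attribute.
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())
        self.head = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        # The backbone is frozen before training starts, so only the head is
        # handed to the optimizer; unfreezing later adds a new param group.
        return torch.optim.Adam((p for p in self.parameters() if p.requires_grad), lr=1e-3)


trainer = Trainer(
    max_epochs=5,
    callbacks=[
        BackboneFinetuning(unfreeze_backbone_at_epoch=2),
        LearningRateMonitor(logging_interval='epoch'),
    ],
)
# trainer.fit(LitModel(), train_dataloader)  # train_dataloader defined elsewhere
```

With this fix, once the backbone group is added, the series previously logged as `lr-Adam` continues under `lr-Adam/pg1` and a fresh `lr-Adam/pg2` series starts for the unfrozen backbone.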

File tree

3 files changed: +150, -18 lines

  - CHANGELOG.md
  - pytorch_lightning/callbacks/lr_monitor.py
  - tests/callbacks/test_lr_monitor.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -211,6 +211,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where checking `trainer.precision` changed to `'mixed'` when specifying 16 in trainer ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))


+- Fixed `LearningRateMonitor` keys not properly setup when running with `BackboneFinetuning` Callback ([#7835](https://github.com/PyTorchLightning/pytorch-lightning/pull/7835))
+
+
 ## [1.3.2] - 2021-05-18

 ### Changed

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 45 additions & 18 deletions
@@ -19,8 +19,10 @@
 Monitor and logs learning rate for lr schedulers during training.

 """
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Type

-from typing import Dict, List, Optional
+from torch.optim.optimizer import Optimizer

 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
@@ -53,7 +55,7 @@ class LearningRateMonitor(Callback):
     In case of multiple optimizers of same type, they will be named ``Adam``,
     ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
     be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
-    ``name`` keyword in the construction of the learning rate schdulers
+    ``name`` keyword in the construction of the learning rate schedulers

     Example::

@@ -138,6 +140,9 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
         latest_stat = {}

+        names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
+        self._remap_keys(names)
+
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
             if scheduler['interval'] == interval or interval == 'any':
                 opt = scheduler['scheduler'].optimizer
@@ -146,7 +151,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

                 for i, pg in enumerate(param_groups):
                     suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
-                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
+                    lr = self._extract_lr(pg, f'{name}{suffix}')
                     latest_stat.update(lr)
                     momentum = self._extract_momentum(
                         param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
@@ -155,48 +160,70 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

         return latest_stat

-    def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
+    def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         lr = param_group.get('lr')
         self.lrs[name].append(lr)
         return {name: lr}

-    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
+    def _remap_keys(self, names: List[str], token: str = '/pg1') -> None:
+        """
+        This function is used the remap the keys if param groups for a given optimizer increased.
+        """
+        for new_name in names:
+            old_name = new_name.replace(token, '')
+            if token in new_name and old_name in self.lrs:
+                self.lrs[new_name] = self.lrs.pop(old_name)
+            elif new_name not in self.lrs:
+                self.lrs[new_name] = []
+
+    def _extract_momentum(self, param_group: Dict[str, Any], name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
             return {}

         momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
         self.last_momentum_values[name] = momentum
         return {name: momentum}

-    def _find_names(self, lr_schedulers) -> List[str]:
-        # Create uniqe names in the case we have multiple of the same learning
-        # rate schduler + multiple parameter groups
+    def _add_prefix(
+        self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
+    ) -> str:
+        if optimizer_cls not in seen_optimizer_types:
+            return name
+        count = seen_optimizer_types[optimizer_cls]
+        return name + f'-{count - 1}' if count > 1 else name
+
+    def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
+        # Create unique names in the case we have multiple of the same learning
+        # rate scheduler + multiple parameter groups
         names = []
+        seen_optimizers = []
+        seen_optimizer_types = defaultdict(int)
         for scheduler in lr_schedulers:
             sch = scheduler['scheduler']
             if scheduler['name'] is not None:
                 name = scheduler['name']
             else:
-                opt_name = 'lr-' + sch.optimizer.__class__.__name__
-                i, name = 1, opt_name
+                name = 'lr-' + sch.optimizer.__class__.__name__

-                # Multiple schduler of the same type
-                while True:
-                    if name not in names:
-                        break
-                    i, name = i + 1, f'{opt_name}-{i}'
+            seen_optimizers.append(sch.optimizer)
+            optimizer_cls = type(sch.optimizer)
+            if scheduler['name'] is None:
+                seen_optimizer_types[optimizer_cls] += 1

-            # Multiple param groups for the same schduler
+            # Multiple param groups for the same scheduler
             param_groups = sch.optimizer.param_groups

+            name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
+
             if len(param_groups) != 1:
-                for i, pg in enumerate(param_groups):
+                for i in range(len(param_groups)):
                     temp = f'{name}/pg{i + 1}'
                     names.append(temp)
             else:
                 names.append(name)

-            self.lr_sch_names.append(name)
+            if add_lr_sch_names:
+                self.lr_sch_names.append(name)

         return names
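The key-remapping logic added above can be exercised in isolation. This is a standalone sketch that mirrors the behaviour of `_remap_keys` from the diff (it is not the library code itself): an optimizer that starts with a single parameter group is logged under `lr-Adam`, and once a second group appears the old history is carried over to `lr-Adam/pg1` while `lr-Adam/pg2` starts empty.

```python
from collections import defaultdict

# Simplified stand-in for LearningRateMonitor.lrs
lrs = defaultdict(list)
lrs['lr-Adam'] = [0.1, 0.05]  # history recorded while Adam had a single param group


def remap_keys(names, lrs, token='/pg1'):
    # Mirrors the diff: move the single-group history under '<name>/pg1' once
    # per-group names exist, and create empty histories for brand-new keys.
    for new_name in names:
        old_name = new_name.replace(token, '')
        if token in new_name and old_name in lrs:
            lrs[new_name] = lrs.pop(old_name)
        elif new_name not in lrs:
            lrs[new_name] = []


# After a finetuning callback adds a second param group, _find_names yields per-group keys
remap_keys(['lr-Adam/pg1', 'lr-Adam/pg2'], lrs)
assert list(lrs) == ['lr-Adam/pg1', 'lr-Adam/pg2']
assert lrs['lr-Adam/pg1'] == [0.1, 0.05]  # old history preserved under the new key
assert lrs['lr-Adam/pg2'] == []           # fresh history for the added group
```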

tests/callbacks/test_lr_monitor.py

Lines changed: 102 additions & 0 deletions
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+import torch
 from torch import optim

 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.callbacks.finetuning import BackboneFinetuning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.datamodules import ClassifDataModule
@@ -278,3 +281,102 @@ def configure_optimizers(self):
     )
     trainer.fit(TestModel())
     assert lr_monitor.lr_sch_names == list(lr_monitor.lrs.keys()) == ['my_logging_name']
+
+
+def test_multiple_optimizers_basefinetuning(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.backbone = torch.nn.Sequential(
+                torch.nn.Linear(32, 32),
+                torch.nn.Linear(32, 32),
+                torch.nn.Linear(32, 32),
+                torch.nn.ReLU(True),
+            )
+            self.layer = torch.nn.Linear(32, 2)
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            return super().training_step(batch, batch_idx)
+
+        def forward(self, x):
+            return self.layer(self.backbone(x))
+
+        def configure_optimizers(self):
+            parameters = list(filter(lambda p: p.requires_grad, self.parameters()))
+            opt = optim.Adam(parameters, lr=0.1)
+            opt_2 = optim.Adam(parameters, lr=0.1)
+            opt_3 = optim.Adam(parameters, lr=0.1)
+            optimizers = [opt, opt_2, opt_3]
+            schedulers = [
+                optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5),
+                optim.lr_scheduler.StepLR(opt_2, step_size=1, gamma=0.5),
+            ]
+            return optimizers, schedulers
+
+    class Check(Callback):
+
+        def on_train_epoch_start(self, trainer, pl_module) -> None:
+            num_param_groups = sum([len(opt.param_groups) for opt in trainer.optimizers])
+            assert lr_monitor.lr_sch_names == ['lr-Adam', 'lr-Adam-1']
+            if trainer.current_epoch == 0:
+                assert num_param_groups == 3
+            elif trainer.current_epoch == 1:
+                assert num_param_groups == 4
+                assert list(lr_monitor.lrs) == ['lr-Adam-1', 'lr-Adam/pg1', 'lr-Adam/pg2']
+            elif trainer.current_epoch == 2:
+                assert num_param_groups == 5
+                assert list(lr_monitor.lrs) == ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2']
+            else:
+                expected = ['lr-Adam/pg1', 'lr-Adam/pg2', 'lr-Adam-1/pg1', 'lr-Adam-1/pg2', 'lr-Adam-1/pg3']
+                assert list(lr_monitor.lrs) == expected
+
+    class TestFinetuning(BackboneFinetuning):
+
+        def freeze_before_training(self, pl_module):
+            self.freeze(pl_module.backbone[0])
+            self.freeze(pl_module.backbone[1])
+            self.freeze(pl_module.layer)
+
+        def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
+            """Called when the epoch begins."""
+            if epoch == 1 and opt_idx == 0:
+                self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
+            if epoch == 2 and opt_idx == 1:
+                self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)
+
+            if epoch == 3 and opt_idx == 1:
+                assert len(optimizer.param_groups) == 2
+                self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
+                assert len(optimizer.param_groups) == 3
+
+    lr_monitor = LearningRateMonitor()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        limit_val_batches=0,
+        limit_train_batches=2,
+        callbacks=[TestFinetuning(), lr_monitor, Check()],
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+        checkpoint_callback=False
+    )
+    model = TestModel()
+    model.training_epoch_end = None
+    trainer.fit(model)
+
+    expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
+    assert lr_monitor.lrs['lr-Adam/pg1'] == expected
+
+    expected = [0.1, 0.05, 0.025, 0.0125]
+    assert lr_monitor.lrs['lr-Adam/pg2'] == expected
+
+    expected = [0.1, 0.05, 0.025, 0.0125, 0.00625]
+    assert lr_monitor.lrs['lr-Adam-1/pg1'] == expected
+
+    expected = [0.1, 0.05, 0.025]
+    assert lr_monitor.lrs['lr-Adam-1/pg2'] == expected
+
+    expected = [0.1, 0.05]
+    assert lr_monitor.lrs['lr-Adam-1/pg3'] == expected
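The expected lists asserted at the end of the test follow from `StepLR(step_size=1, gamma=0.5)` halving each learning rate once per epoch over the 5-epoch run, with groups added later by `unfreeze_and_add_param_group` starting at `lr=0.1` in the epoch they appear and therefore accumulating fewer entries. A quick illustrative check of that arithmetic (not part of the commit):

```python
def steplr_history(initial_lr, num_epochs, gamma=0.5):
    # Learning rate recorded once per epoch under StepLR(step_size=1, gamma)
    return [initial_lr * gamma ** epoch for epoch in range(num_epochs)]


# Groups present from epoch 0 are logged for all 5 epochs
assert steplr_history(0.1, 5) == [0.1, 0.05, 0.025, 0.0125, 0.00625]

# A group added at epoch 2 only accumulates the remaining 3 epochs
assert steplr_history(0.1, 3) == [0.1, 0.05, 0.025]
```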
