 Monitor and logs learning rate for lr schedulers during training.

 """
+from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Type

-from typing import Dict, List, Optional
+from torch.optim.optimizer import Optimizer

 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
@@ -53,7 +55,7 @@ class LearningRateMonitor(Callback):
     In case of multiple optimizers of same type, they will be named ``Adam``,
     ``Adam-1`` etc. If a optimizer has multiple parameter groups they will
     be named ``Adam/pg1``, ``Adam/pg2`` etc. To control naming, pass in a
-    ``name`` keyword in the construction of the learning rate schdulers
+    ``name`` keyword in the construction of the learning rate schedulers

     Example::

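For reference, a minimal usage sketch (not part of this diff) of how the ``name`` keyword and the callback fit together; the module, optimizer and scheduler shown here are placeholders:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import LearningRateMonitor

    class DemoModule(pl.LightningModule):  # placeholder module for illustration
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
            # the 'name' key controls the label under which this scheduler's lr is logged
            return {
                'optimizer': optimizer,
                'lr_scheduler': {'scheduler': scheduler, 'name': 'my-lr'},
            }

    trainer = pl.Trainer(callbacks=[LearningRateMonitor(logging_interval='epoch')])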
@@ -138,6 +140,9 @@ def on_train_epoch_start(self, trainer, *args, **kwargs):
     def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:
         latest_stat = {}

+        names = self._find_names(trainer.lr_schedulers, add_lr_sch_names=False)
+        self._remap_keys(names)
+
         for name, scheduler in zip(self.lr_sch_names, trainer.lr_schedulers):
             if scheduler['interval'] == interval or interval == 'any':
                 opt = scheduler['scheduler'].optimizer
@@ -146,7 +151,7 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

                 for i, pg in enumerate(param_groups):
                     suffix = f'/pg{i + 1}' if len(param_groups) > 1 else ''
-                    lr = self._extract_lr(param_group=pg, name=f'{name}{suffix}')
+                    lr = self._extract_lr(pg, f'{name}{suffix}')
                     latest_stat.update(lr)
                     momentum = self._extract_momentum(
                         param_group=pg, name=f'{name}-momentum{suffix}', use_betas=use_betas
@@ -155,48 +160,70 @@ def _extract_stats(self, trainer, interval: str) -> Dict[str, float]:

         return latest_stat

-    def _extract_lr(self, param_group, name: str) -> Dict[str, float]:
+    def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]:
         lr = param_group.get('lr')
         self.lrs[name].append(lr)
         return {name: lr}

-    def _extract_momentum(self, param_group, name: str, use_betas: bool) -> Dict[str, float]:
+    def _remap_keys(self, names: List[str], token: str = '/pg1') -> None:
+        """
+        This function is used to remap the keys if the param groups for a given optimizer increased.
+        """
+        for new_name in names:
+            old_name = new_name.replace(token, '')
+            if token in new_name and old_name in self.lrs:
+                self.lrs[new_name] = self.lrs.pop(old_name)
+            elif new_name not in self.lrs:
+                self.lrs[new_name] = []
+
+    def _extract_momentum(self, param_group: Dict[str, Any], name: str, use_betas: bool) -> Dict[str, float]:
         if not self.log_momentum:
             return {}

         momentum = param_group.get('betas')[0] if use_betas else param_group.get('momentum', 0)
         self.last_momentum_values[name] = momentum
         return {name: momentum}

-    def _find_names(self, lr_schedulers) -> List[str]:
-        # Create uniqe names in the case we have multiple of the same learning
-        # rate schduler + multiple parameter groups
+    def _add_prefix(
+        self, name: str, optimizer_cls: Type[Optimizer], seen_optimizer_types: DefaultDict[Type[Optimizer], int]
+    ) -> str:
+        if optimizer_cls not in seen_optimizer_types:
+            return name
+        count = seen_optimizer_types[optimizer_cls]
+        return name + f'-{count - 1}' if count > 1 else name
+
+    def _find_names(self, lr_schedulers: List, add_lr_sch_names: bool = True) -> List[str]:
+        # Create unique names in the case we have multiple of the same learning
+        # rate scheduler + multiple parameter groups
         names = []
+        seen_optimizers = []
+        seen_optimizer_types = defaultdict(int)
         for scheduler in lr_schedulers:
             sch = scheduler['scheduler']
             if scheduler['name'] is not None:
                 name = scheduler['name']
             else:
-                opt_name = 'lr-' + sch.optimizer.__class__.__name__
-                i, name = 1, opt_name
+                name = 'lr-' + sch.optimizer.__class__.__name__

-                # Multiple schduler of the same type
-                while True:
-                    if name not in names:
-                        break
-                    i, name = i + 1, f'{opt_name}-{i}'
+            seen_optimizers.append(sch.optimizer)
+            optimizer_cls = type(sch.optimizer)
+            if scheduler['name'] is None:
+                seen_optimizer_types[optimizer_cls] += 1

-            # Multiple param groups for the same schduler
+            # Multiple param groups for the same scheduler
             param_groups = sch.optimizer.param_groups

+            name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
+
             if len(param_groups) != 1:
-                for i, pg in enumerate(param_groups):
+                for i in range(len(param_groups)):
                     temp = f'{name}/pg{i + 1}'
                     names.append(temp)
             else:
                 names.append(name)

-            self.lr_sch_names.append(name)
+            if add_lr_sch_names:
+                self.lr_sch_names.append(name)

         return names

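As an illustration of the key remapping that the ``_remap_keys`` helper added above performs, the standalone sketch below (names and values are illustrative only) shows what happens once an optimizer gains a second parameter group: the history recorded under the plain name is carried over to the new ``/pg1`` key and a fresh list is started for ``/pg2``:

    # standalone mirror of the _remap_keys logic from the diff above
    def remap_keys(lrs, names, token='/pg1'):
        for new_name in names:
            old_name = new_name.replace(token, '')
            if token in new_name and old_name in lrs:
                lrs[new_name] = lrs.pop(old_name)
            elif new_name not in lrs:
                lrs[new_name] = []

    lrs = {'lr-Adam': [0.1, 0.05]}  # history logged while there was a single param group
    remap_keys(lrs, ['lr-Adam/pg1', 'lr-Adam/pg2'])
    print(lrs)  # {'lr-Adam/pg1': [0.1, 0.05], 'lr-Adam/pg2': []}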