 from lightning.pytorch.utilities.types import STEP_OUTPUT


-def _return_true(x: int) -> bool:
-    return True
-
-
-def _return_false(x: int) -> bool:
-    return False
-
-
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.

-    The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
-    model should be updated. If neither function is provided, the average model will be updated after every optimizer
-    step.
+    The user can customize when the average model is updated by overriding the ``should_update()`` method.

     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
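For context, wiring the callback up for EMA could look like the following minimal sketch (an illustration, not part of this diff). It assumes the callback is exported from lightning.pytorch.callbacks and uses PyTorch's built-in get_ema_avg_fn helper to build an avg_fn with the expected signature:

    # Minimal usage sketch: an exponential moving average with decay 0.999.
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    trainer = Trainer(callbacks=[WeightAveraging(avg_fn=get_ema_avg_fn(decay=0.999))])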
@@ -55,40 +45,44 @@ class WeightAveraging(Callback):
         avg_fn: The averaging function used to update the parameters. The function must take in an
             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
             ``None``, an equally weighted average will be used.
-        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
-            model should be updated.
-        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
-            should be updated.

     """

     def __init__(
         self,
         device: Optional[Union[torch.device, int]] = torch.device("cpu"),
         avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-        update_on_step: Optional[Callable[[int], bool]] = None,
-        update_on_epoch: Optional[Callable[[int], bool]] = None,
     ):
         self._device = device
         self._avg_fn = avg_fn
-
-        if (update_on_step is None) and (update_on_epoch is None):
-            self._update_on_step: Callable[[int], bool] = _return_true
-            self._update_on_epoch: Callable[[int], bool] = _return_false
-        else:
-            self._update_on_step = _return_false if update_on_step is None else update_on_step
-            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
-
         self._average_model: Optional[AveragedModel] = None

         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
-        # that the average model will be first updated after the first optimizer step, which takes place after N batches
-        # when using accumulate_grad_batches=N.
+        # that self.should_update() will be first called after the first optimizer step, which takes place after N
+        # batches when using accumulate_grad_batches=N.
         self._latest_update_step = 0
         # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
-        # negative value means that if update_on_step(0) returns True, the first update is after the first epoch.
+        # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
+        # epoch.
         self._latest_update_epoch = -1

+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Called after every optimizer step and after every training epoch to check whether the average model should
+        be updated.
+
+        One of the arguments is set to the zero-based index of the last training step or epoch. The user can customize
+        when the average model gets updated by overriding this method.
+
+        Args:
+            step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
+            epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.
+
+        Returns:
+            ``True`` if the average model should be updated and ``False`` if not.
+
+        """
+        return step_idx is not None
+
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
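To illustrate the new hook, a subclass that performs SWA-style updates at epoch boundaries, rather than after every optimizer step (the default), might look like this sketch; the swa_start_epoch threshold is a hypothetical knob invented for the example:

    from typing import Optional

    from lightning.pytorch.callbacks import WeightAveraging

    class SWAAtEpochEnd(WeightAveraging):
        """Sketch: update the average model only at epoch ends."""

        def __init__(self, swa_start_epoch: int = 10, **kwargs):
            super().__init__(**kwargs)
            self._swa_start_epoch = swa_start_epoch  # hypothetical parameter

        def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
            # Ignore the per-step calls and update from swa_start_epoch onwards.
            return epoch_idx is not None and epoch_idx >= self._swa_start_epoch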
@@ -109,7 +103,7 @@ def on_train_batch_end(
     ) -> None:
         """Called when a training batch ends.

-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -119,22 +113,25 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.

         """
-        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
+        # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
+        # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
+        step_idx = trainer.global_step - 1
+        if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a training epoch ends.

-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.

         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.

         """
-        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+        if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_epoch = trainer.current_epoch
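Since step_idx is now zero-based, an override that updates the average model every N optimizer steps should test (step_idx + 1) % N, as in this sketch:

    # Sketch: update on every 100th optimizer step, i.e. step_idx 99, 199, ...
    def should_update(self, step_idx=None, epoch_idx=None):
        return step_idx is not None and (step_idx + 1) % 100 == 0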
@@ -218,17 +215,21 @@ def on_save_checkpoint(

         """
         if self._average_model is None:
-            raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
-        average_model_state = self._average_model.state_dict()
-        checkpoint["current_model_state"] = checkpoint["state_dict"]
-        checkpoint["state_dict"] = {
-            name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
-        }
-        checkpoint["averaging_state"] = {
-            name: value for name, value in average_model_state.items() if not name.startswith("module.")
-        }
+            rank_zero_info(
+                "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
+                "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
+                "average model parameters will be saved to the state_dict in the checkpoint."
+            )
+        else:
+            rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
+            average_model_state = self._average_model.state_dict()
+            checkpoint["current_model_state"] = checkpoint["state_dict"]
+            checkpoint["state_dict"] = {
+                name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
+            }
+            checkpoint["averaging_state"] = {
+                name: value for name, value in average_model_state.items() if not name.startswith("module.")
+            }

     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
@@ -244,9 +245,12 @@ def on_load_checkpoint(

         """
         if self._average_model is None:
-            raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
+            rank_zero_warn(
+                "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
+                "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
+                "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
+            )
+        elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
             average_model_state |= checkpoint["averaging_state"]
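For reference, the key juggling above relies on AveragedModel storing the wrapped model under a "module." prefix alongside its own buffers (such as n_averaged). A rough sketch of the round trip, with placeholder values:

    # Illustrative sketch of the checkpoint key layout (placeholder values).
    average_model_state = {"module.layer.weight": 0.0, "n_averaged": 1}

    # on_save_checkpoint: strip the prefix so state_dict matches the LightningModule.
    state_dict = {n[7:]: v for n, v in average_model_state.items() if n.startswith("module.")}

    # on_load_checkpoint: re-add the prefix and merge the averaging buffers back in.
    restored = {"module." + n: v for n, v in state_dict.items()}
    restored |= {n: v for n, v in average_model_state.items() if not n.startswith("module.")}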