
Commit 51b9a06

Author: Seppo Enarvi (committed)
The user can customize WeightAveraging updates by overriding the should_update() method
1 parent 075bfcf commit 51b9a06
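
With this change, update scheduling moves from constructor arguments to a subclass override. A minimal sketch of the new usage (the import path follows the file location in this commit; the class name and epoch numbers are only illustrative, mirroring the test callback below):

    from typing import Optional

    from lightning.pytorch.callbacks.weight_averaging import WeightAveraging


    class EpochAveraging(WeightAveraging):
        """Update the average model only after selected epochs."""

        def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
            # epoch_idx is set when called at the epoch end; step_idx is set
            # when called after an optimizer step.
            return epoch_idx in (3, 5, 7)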

File tree

2 files changed: +54 −50 lines

src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 50 additions & 46 deletions
@@ -30,21 +30,11 @@
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 
 
-def _return_true(x: int) -> bool:
-    return True
-
-
-def _return_false(x: int) -> bool:
-    return False
-
-
 class WeightAveraging(Callback):
     r"""A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average
     (EMA) after each training step.
 
-    The user should provide either `update_on_step` or `update_on_epoch`, a function that determines when the average
-    model should be updated. If neither function is provided, the average model will be updated after every optimizer
-    step.
+    The user can customize when the average model is updated by overriding the ``should_update()`` method.
 
     During validation and after the training finishes, the current model parameters will be replaced with the averaged
     values.
@@ -55,40 +45,44 @@ class WeightAveraging(Callback):
         avg_fn: The averaging function used to update the parameters. The function must take in an
             :class:`AveragedModel` parameter, a current model parameter, and the number of models already averaged. If
             ``None``, an equally weighted average will be used.
-        update_on_step: A function that takes the number of optimizer steps taken, and returns ``True`` if the average
-            model should be updated.
-        update_on_epoch: A function that takes the zero-based epoch number, and returns ``True`` if the average model
-            should be updated.
 
     """
 
     def __init__(
         self,
         device: Optional[Union[torch.device, int]] = torch.device("cpu"),
         avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
-        update_on_step: Optional[Callable[[int], bool]] = None,
-        update_on_epoch: Optional[Callable[[int], bool]] = None,
     ):
         self._device = device
         self._avg_fn = avg_fn
-
-        if (update_on_step is None) and (update_on_epoch is None):
-            self._update_on_step: Callable[[int], bool] = _return_true
-            self._update_on_epoch: Callable[[int], bool] = _return_false
-        else:
-            self._update_on_step = _return_false if update_on_step is None else update_on_step
-            self._update_on_epoch = _return_false if update_on_epoch is None else update_on_epoch
-
         self._average_model: Optional[AveragedModel] = None
 
         # Number of optimizer steps taken, when the average model was last updated. Initializing this with zero ensures
-        # that the average model will be first updated after the first optimizer step, which takes place after N batches
-        # when using accumulate_grad_batches=N.
+        # that self.should_update() will be first called after the first optimizer step, which takes place after N
+        # batches when using accumulate_grad_batches=N.
         self._latest_update_step = 0
         # The epoch after which the average model was last updated. The first epoch is 0, so initializing this to a
-        # negative value means that if update_on_step(0) returns True, the first update is after the first epoch.
+        # negative value means that if self.should_update(epoch_idx=0) returns True, the first update is after the first
+        # epoch.
         self._latest_update_epoch = -1
 
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        """Called after every optimizer step and after every training epoch to check whether the average model should
+        be updated.
+
+        One of the arguments is set to the zero-based index of the last training step or epoch. The user can customize
+        when the average model gets updated by overriding this method.
+
+        Args:
+            step_idx: Index of the last optimizer step, or ``None`` when called at the epoch end.
+            epoch_idx: Index of the last epoch, or ``None`` when called after an optimizer step.
+
+        Returns:
+            ``True`` if the average model should be updated and ``False`` if not.
+
+        """
+        return step_idx is not None
+
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
 
@@ -109,7 +103,7 @@ def on_train_batch_end(
     ) -> None:
         """Called when a training batch ends.
 
-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_step()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
@@ -119,22 +113,25 @@ def on_train_batch_end(
             batch_idx: Index of the training batch.
 
         """
-        if self._update_on_step(trainer.global_step) and (trainer.global_step > self._latest_update_step):
+        # trainer.global_step is the number of optimizer steps taken so far, i.e. 1 after the first optimizer step. To
+        # make step_idx consistent with epoch_idx, we'll pass a zero-based index.
+        step_idx = trainer.global_step - 1
+        if (trainer.global_step > self._latest_update_step) and self.should_update(step_idx=step_idx):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a training epoch ends.
 
-        Updates the :class:`AveragedModel` parameters, if requested by ``update_on_epoch()``.
+        Updates the :class:`AveragedModel` parameters, if requested by ``self.should_update()``.
 
         Args:
             trainer: The current :class:`~lightning.pytorch.trainer.trainer.Trainer` instance.
             pl_module: The current :class:`~lightning.pytorch.core.LightningModule` instance.
 
         """
-        if self._update_on_epoch(trainer.current_epoch) and (trainer.current_epoch > self._latest_update_epoch):
+        if (trainer.current_epoch > self._latest_update_epoch) and self.should_update(epoch_idx=trainer.current_epoch):
             assert self._average_model is not None
             self._average_model.update_parameters(pl_module)
             self._latest_update_epoch = trainer.current_epoch
@@ -218,17 +215,21 @@ def on_save_checkpoint(
 
         """
         if self._average_model is None:
-            raise Exception("Trying to save a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
-        average_model_state = self._average_model.state_dict()
-        checkpoint["current_model_state"] = checkpoint["state_dict"]
-        checkpoint["state_dict"] = {
-            name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
-        }
-        checkpoint["averaging_state"] = {
-            name: value for name, value in average_model_state.items() if not name.startswith("module.")
-        }
+            rank_zero_info(
+                "You're using the WeightAveraging callback, but saving a checkpoint outside the 'fit' stage. The state "
+                "of the WeightAveraging callback won't be saved in the checkpoint. If training has finished, the "
+                "average model parameters will be saved to the state_dict in the checkpoint."
+            )
+        else:
+            rank_zero_info("The average model parameters will be saved to the state_dict in the checkpoint.")
+            average_model_state = self._average_model.state_dict()
+            checkpoint["current_model_state"] = checkpoint["state_dict"]
+            checkpoint["state_dict"] = {
+                name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
+            }
+            checkpoint["averaging_state"] = {
+                name: value for name, value in average_model_state.items() if not name.startswith("module.")
+            }
 
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
@@ -244,9 +245,12 @@ def on_load_checkpoint(
 
         """
         if self._average_model is None:
-            raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
-
-        if ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
+            rank_zero_warn(
+                "You're using the WeightAveraging callback, but loading a checkpoint outside the 'fit' stage. The "
+                "WeightAveraging state cannot be restored. If you're using the checkpoint for prediction or testing, "
+                "you can ignore this warning. To disable the warning, remove the WeightAveraging callback."
+            )
+        elif ("current_model_state" in checkpoint) and ("averaging_state" in checkpoint):
             rank_zero_info("Found current_model_state in the checkpoint. This will be used to initialize the model.")
             average_model_state = {"module." + name: value for name, value in checkpoint["state_dict"].items()}
             average_model_state |= checkpoint["averaging_state"]
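
Because on_train_batch_end now passes a zero-based step_idx (trainer.global_step - 1), a subclass that averages every N optimizer steps keys off that index. A short sketch, assuming an arbitrary interval of 10 steps and a hypothetical class name, neither of which is part of this commit:

    from typing import Optional

    from lightning.pytorch.callbacks.weight_averaging import WeightAveraging


    class IntervalAveraging(WeightAveraging):
        """Update the average model on every tenth optimizer step."""

        def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
            # step_idx is zero-based, so indices 9, 19, 29, ... correspond to
            # the 10th, 20th, 30th optimizer steps.
            return step_idx is not None and (step_idx + 1) % 10 == 0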

tests/tests_pytorch/callbacks/test_weight_averaging.py

Lines changed: 4 additions & 4 deletions
@@ -119,15 +119,15 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
 
 class SWATestCallback(WeightAveraging):
     def __init__(self, **kwargs: Any) -> None:
-        avg_fn = get_swa_avg_fn()
-        update_on_epoch = lambda x: x in (3, 5, 7)
-        super().__init__(avg_fn=avg_fn, update_on_epoch=update_on_epoch, **kwargs)
-
+        super().__init__(avg_fn=get_swa_avg_fn(), **kwargs)
         self.swap_calls = 0
         self.copy_calls = 0
         # Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0.
         self.first_epoch: Optional[int] = None
 
+    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
+        return epoch_idx in (3, 5, 7)
+
     def _swap_models(self, *args: Any, **kwargs: Any):
         self.swap_calls += 1
         return super()._swap_models(*args, **kwargs)
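
The test mirrors how user code adopts the new hook: subclass WeightAveraging, override should_update(), and attach the callback to the Trainer as before. A usage sketch, where model and datamodule stand in for the user's own objects:

    import lightning.pytorch as pl

    from lightning.pytorch.callbacks.weight_averaging import WeightAveraging

    # The default should_update() returns True whenever step_idx is given,
    # i.e. the average model is updated after every optimizer step.
    trainer = pl.Trainer(max_epochs=10, callbacks=[WeightAveraging()])
    # trainer.fit(model, datamodule=datamodule)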
