Generic weight averaging callback that supports EMA #20545


Open
wants to merge 12 commits into base: master
43 changes: 32 additions & 11 deletions docs/source-pytorch/advanced/training_tricks.rst
@@ -50,23 +50,44 @@ Read more about :ref:`Configuring Gradient Clipping <configure_gradient_clipping

----------

***************************
Stochastic Weight Averaging
***************************
****************
Weight Averaging
****************

Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
This can be used with both non-trained and trained models. The SWA procedure smooths the loss landscape thus making
it harder to end up in a local minimum during optimization.
Weight averaging methods such as Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA) can make your
models generalize better at virtually no additional cost. Averaging smooths the loss landscape, making it harder to end
up in a local minimum during optimization.

For a more detailed explanation of SWA and how it works,
read `this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.
Lightning provides two callbacks to facilitate weight averaging. :class:`~lightning.pytorch.callbacks.WeightAveraging`
is a generic callback that wraps the
`AveragedModel <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.AveragedModel.html>`__ class from
PyTorch. It allows SWA, EMA, or a custom averaging strategy to be used, and it can be customized to update the average
at specific steps or epochs.
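
For example, a custom averaging strategy can be plugged in by providing your own ``avg_fn`` and deciding when the
average is refreshed. The following is a minimal sketch rather than part of the Lightning API: the
``EveryEpochWeightAveraging`` name and the hand-written running mean are illustrative, and it assumes that ``avg_fn``
is forwarded to ``AveragedModel`` and that ``should_update`` controls when updates happen, as in the EMA example below.

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging


    def mean_avg_fn(averaged_param, current_param, num_averaged):
        # Equal-weight running mean; mirrors the default update rule of AveragedModel.
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)


    class EveryEpochWeightAveraging(WeightAveraging):
        """Illustrative subclass that refreshes the average once per epoch."""

        def __init__(self):
            super().__init__(avg_fn=mean_avg_fn)

        def should_update(self, step_idx=None, epoch_idx=None):
            # Update at epoch boundaries only, instead of after individual steps.
            return epoch_idx is not None


    trainer = Trainer(callbacks=EveryEpochWeightAveraging())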

.. seealso:: The :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback
The older :class:`~lightning.pytorch.callbacks.StochasticWeightAveraging` callback is specific to SWA. It starts the SWA
procedure after a given number of epochs and always updates the average on every epoch. Additionally, it switches to a
constant learning rate schedule (`SWALR <https://pytorch.org/docs/stable/generated/torch.optim.swa_utils.SWALR.html>`__)
when the procedure starts.

.. seealso::
For a more detailed explanation of SWA and how it works, read
`this post <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ by the PyTorch team.

.. testcode::

    # Enable Stochastic Weight Averaging using the callback
    trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])

    from lightning.pytorch.callbacks import StochasticWeightAveraging, WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn


    # Enable Exponential Moving Average after 100 steps
    class EMAWeightAveraging(WeightAveraging):
        def __init__(self):
            super().__init__(avg_fn=get_ema_avg_fn())

        def should_update(self, step_idx=None, epoch_idx=None):
            return (step_idx is not None) and (step_idx >= 100)


    trainer = Trainer(callbacks=EMAWeightAveraging())

    # Enable Stochastic Weight Averaging after 10 epochs with learning rate 0.01
    trainer = Trainer(callbacks=StochasticWeightAveraging(swa_epoch_start=10, swa_lrs=0.01))
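
The EMA decay can also be tuned without subclassing. A brief sketch, assuming ``WeightAveraging`` accepts ``avg_fn`` as
a constructor argument (as passed via ``super().__init__`` above) and that ``get_ema_avg_fn`` takes a ``decay``
argument as documented by PyTorch (default 0.999):

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import WeightAveraging
    from torch.optim.swa_utils import get_ema_avg_fn

    # A smaller decay makes the averaged weights track recent updates more closely.
    trainer = Trainer(callbacks=WeightAveraging(avg_fn=get_ema_avg_fn(decay=0.99)))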

----------

1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -48,6 +48,7 @@ callbacks
ThroughputMonitor
Timer
TQDMProgressBar
WeightAveraging

cli
-----
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/callbacks.rst
@@ -83,6 +83,7 @@ Lightning has a few built-in callbacks.
StochasticWeightAveraging
Timer
TQDMProgressBar
WeightAveraging

----------

16 changes: 8 additions & 8 deletions docs/source-pytorch/glossary/index.rst
@@ -42,13 +42,13 @@
Strategy registry <../advanced/strategy_registry>
Strategy integrations <../integrations/strategies/index>
Style guide <../starter/style_guide>
SWA <../advanced/training_tricks>
SLURM <../clouds/cluster_advanced>
Tensor Parallel <../advanced/model_parallel/tp>
Transfer learning <../advanced/transfer_learning>
Trainer <../common/trainer>
TorchRun (TorchElastic) <../clouds/cluster_intermediate_2>
Warnings <../advanced/warnings>
Weight averaging <../advanced/training_tricks>


########
@@ -326,13 +326,6 @@ Glossary
:button_link: ../starter/style_guide.html
:height: 100

.. displayitem::
:header: SWA
:description: Stochastic Weight Averaging (SWA) can make your models generalize better
:col_css: col-md-12
:button_link: ../advanced/training_tricks.html#stochastic-weight-averaging
:height: 100

.. displayitem::
:header: SLURM
:description: Simple Linux Utility for Resource Management, or simply Slurm, is a free and open-source job scheduler for Linux clusters
@@ -375,6 +368,13 @@ Glossary
:button_link: ../advanced/warnings.html
:height: 100

.. displayitem::
:header: Weight averaging
:description: Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) can make your models generalize better
:col_css: col-md-12
:button_link: ../advanced/training_tricks.html#weight-averaging
:height: 100

.. raw:: html

</div>
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/build_model_intermediate.rst
@@ -27,7 +27,7 @@ Enable advanced training features using Trainer arguments. These are SOTA techni
)

# access the latest state of the art techniques
trainer = Trainer(callbacks=[StochasticWeightAveraging(...)])
trainer = Trainer(callbacks=[WeightAveraging(...)])

----

2 changes: 1 addition & 1 deletion docs/source-pytorch/starter/introduction.rst
@@ -252,7 +252,7 @@ Enable advanced training features using Trainer arguments. These are state-of-th
)

# access the latest state of the art techniques
trainer = L.Trainer(callbacks=[StochasticWeightAveraging(...)])
trainer = L.Trainer(callbacks=[WeightAveraging(...)])

----

1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- WeightAveraging callback that wraps the PyTorch AveragedModel class ([#20545](https://github.com/Lightning-AI/pytorch-lightning/pull/20545))
- Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))


2 changes: 2 additions & 0 deletions src/lightning/pytorch/callbacks/__init__.py
@@ -32,6 +32,7 @@
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.callbacks.weight_averaging import WeightAveraging

__all__ = [
"BackboneFinetuning",
@@ -58,4 +59,5 @@
"ThroughputMonitor",
"Timer",
"TQDMProgressBar",
"WeightAveraging",
]
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/stochastic_weight_avg.py
@@ -65,7 +65,7 @@ def __init__(

.. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.

See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Weight Averaging>`.

Arguments:
