
Generic weight averaging callback that supports EMA #20545


Open
wants to merge 12 commits into base: master from the generic-weight-averaging branch

Conversation

senarvi
Contributor

@senarvi senarvi commented Jan 14, 2025

A callback that updates an AveragedModel after every training step

What does this PR do?

This is similar to the existing StochasticWeightAveraging callback, but wraps the AveragedModel class from PyTorch. Reduced code duplication means easier maintenance. Also, any averaging function can be used. By default, the callback does averaging on every step, but this can be customized by overriding the should_update(step_idx, epoch_idx) method.
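A minimal usage sketch (import path and EMA helper as in the examples later in this thread; the decay value is arbitrary):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# Keep an exponential moving average of the weights, updated after every optimizer step.
ema = WeightAveraging(avg_fn=get_ema_avg_fn(decay=0.999))
trainer = Trainer(callbacks=[ema])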

Fixes #10914

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs) => Discussed in issue Add feature Exponential Moving Average (EMA) #10914
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request? => There are none.
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/

@github-actions github-actions bot added the docs (Documentation related) and pl (Generic label for PyTorch Lightning package) labels Jan 14, 2025
@lantiga
Collaborator

lantiga commented Jan 14, 2025

Hey @senarvi, this looks great!

I saw you already added support for saving and resuming, which is great. There are many scenarios there (save every n steps, time-based, every epoch, etc.); let's make sure we cover them all (for inspiration, we added quite a few tests here: #20379)

we could still have different callbacks ("StepwiseAveragingCallback" and "EpochwiseAveragingCallback")

No, I think it's better to have one callback with configurable averaging flags; it's more Lightning-esque.

Constructs the AveragedModel with use_buffers=True, so that an extra step is not needed for updating the batch normalization statistics. StochasticWeightAveraging performs an extra step at the end. Consequently, its implementation is significantly more complex and it's difficult to make sure that it works in all cases. Should we add this as an option in this class too?

I think this is ok, but my doubt with forcing use_buffers to be true is what happens when a user has a module with buffers in it that are not meant to be averaged. I guess at that point they will probably be the same over time (e.g. the RoPE cache), but that's not really a guarantee.

Wdyt about this? I don't necessarily want to make the implementation more complex, so this is just for discussion.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume.

Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping that in with a notice to just use this one will not hurt.

@senarvi
Contributor Author

senarvi commented Jan 15, 2025

I think this is ok, but my doubt with forcing use_buffers to be true is what happens when a user has a module with buffers in it that are not meant to be averaged. I guess at that point they will probably be the same over time (e.g. the RoPE cache), but that's not really a guarantee.

That's a good point. I don't know what would be a good solution.

Updates the average model after every step. StochasticWeightAveraging updates the average model after every epoch, and I recall that the original paper updated it only at certain points (the learning rate minima). I guess it would be nice to be able to select whether the average model will be updated after every step, after every epoch, or after certain epochs. Then we would need only one callback and we could remove the StochasticWeightAveraging callback, but would it make this class too complex?

It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume.

That's an interesting idea. We could have the user pass a function update_on_step(global_step) or update_on_epoch(epoch) that returns a boolean. After each optimizer step and after each epoch we would call the function to check whether we should update the average model.

It seems that AveragedModel will copy the current model parameters when called the first time, and update the average on subsequent calls. This means that the first average is computed when update_on_step() or update_on_epoch() returns True for the second time. I don't see a better alternative.
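For reference, a small standalone sketch of that behavior with plain PyTorch (no Lightning involved):

from torch import nn
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn

model = nn.Linear(4, 4)
average_model = AveragedModel(model, avg_fn=get_ema_avg_fn(0.999))

average_model.update_parameters(model)  # first call: copies the current parameters (n_averaged == 0)
# ... optimizer steps would happen here ...
average_model.update_parameters(model)  # later calls: apply avg_fn to update the running average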

I checked how StochasticWeightAveraging does this and I think it doesn't work correctly. It only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. That just shows why I'd like to keep the logic as simple as possible.

@cyanic-selkie

Hi, I have a couple questions.

  1. You added the on_validation_epoch_start and on_validation_epoch_end hooks to swap the weights, but shouldn't the same happen for test?
  2. In my current workflow I have a separate script that does the model exporting to ONNX. It's short, and really the only Lightning specific thing is the MyLightningModule.load_from_checkpoint(...) method. Since the averaged weights are a part of the callback, I would have to instantiate the trainer for the weights to be loaded. And even then, I wouldn't have a function I could call to explicitly swap the weights (since _swap_weights is private and not really accessible). So, my question is, can we have some sort of an API, outside of the trainer, that can load the averaged weights instead of the regular weights? Perhaps adding some sort of a parameter to the load_from_checkpoint method?

@senarvi
Contributor Author

senarvi commented Jan 16, 2025

Hi @cyanic-selkie

During training (stage=fit), the actual LightningModule is what we update using the optimizer (I call it the current model) and an AveragedModel is maintained in the background (I call it the average model).

I assume that validation is only called during training. Before and after validation we swap the current model and the average model, so the average model will be validated.

When saving a checkpoint, we save the average model parameters in the state_dict. So if you later load the checkpoint without WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters.

When training ends, we copy the average model parameters to the current model. So if you run a test or export to ONNX after training, you will be using the average parameters.
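Roughly, the flow looks like this (an illustrative sketch with simplified hooks and a stand-in _swap helper; not the actual implementation in this PR):

from lightning.pytorch.callbacks import Callback

class AverageSwapSketch(Callback):
    """Illustrative only: mirrors the flow described above."""

    def __init__(self, average_model):
        self._average_model = average_model  # an AveragedModel wrapping the LightningModule

    def _swap(self, pl_module):
        # Exchange parameter values between the current model and the averaged copy, in place.
        for p_avg, p_cur in zip(self._average_model.module.parameters(), pl_module.parameters()):
            tmp = p_avg.detach().clone()
            p_avg.detach().copy_(p_cur.detach())
            p_cur.detach().copy_(tmp)

    def on_validation_epoch_start(self, trainer, pl_module):
        self._swap(pl_module)  # validation sees the average model (only during fit in the real callback)

    def on_validation_epoch_end(self, trainer, pl_module):
        self._swap(pl_module)  # training continues with the current model

    def on_train_end(self, trainer, pl_module):
        # After training, the LightningModule itself carries the averaged weights.
        for p_avg, p_cur in zip(self._average_model.module.parameters(), pl_module.parameters()):
            p_cur.detach().copy_(p_avg.detach())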

That's the idea at least. I'm not confident that I have thought about every possible corner case. It would be great if you could test that it works in your case.

@cyanic-selkie

@senarvi Ah! Thanks for the clarification, I should've checked the code out more carefully. I tried your branch out on a quantization aware training enabled model with ONNX export at the end and everything is working beautifully! I hope this gets merged quickly.

@senarvi senarvi force-pushed the generic-weight-averaging branch from efc77dc to 0010492 on January 23, 2025 16:07
@senarvi
Contributor Author

senarvi commented Jan 23, 2025

The user can now provide either the update_on_step or the update_on_epoch argument. (In theory also both.) It should be a function that takes the step/epoch number and returns True if the average model should be updated at that point in time.

For example:

update_on_step = lambda x: x > 100

or

update_on_epoch = lambda x: x in (3, 5, 7)

Using update_on_epoch, SWA should be possible. I added one unit test for SWA.

I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works.

I think the biggest question that is still left is whether it's a problem that we force use_buffers=True. It would be nice if we could provide the option to instead call update_bn() after training and we wouldn't have to duplicate any of that code. That function takes a data loader and iterates through the data. I can imagine that passing the Trainer's data loader might not work in all cases. We could also leave calling this function to the user.
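For reference, the manual alternative would be roughly this (plain PyTorch, shown with a toy model and data loader):

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))
average_model = AveragedModel(model)  # default use_buffers=False, so BN statistics are not averaged
average_model.update_parameters(model)

# One extra pass over the training data to recompute the BatchNorm running statistics:
train_loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
update_bn(train_loader, average_model)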

StochasticWeightAveraging increments the number of epochs in on_fit_start() and during the extra epoch disables the backward pass. I could also copy the code from that class, but there are some details that I don't understand, and I'm not that excited about copying code that I don't fully understand.

@tchaton I think you contributed the StochasticWeightAveraging callback, maybe you have some insight?

* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
@senarvi senarvi force-pushed the generic-weight-averaging branch from 5f34205 to c8d50bd on January 23, 2025 18:00
@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) label Jan 23, 2025
@cyanic-selkie

Is there anything blocking this from being merged?

@senarvi senarvi changed the title Generic weight averaging callback that supports EMA [wip] Generic weight averaging callback that supports EMA Feb 2, 2025
@senarvi senarvi marked this pull request as ready for review February 2, 2025 21:21
@senarvi
Contributor Author

senarvi commented Feb 2, 2025

I marked this ready for review. There were no comments on whether it's a problem that we force use_buffers=True. Would it make sense to merge this now and perhaps introduce such an option later, based on the feedback that we receive?


codecov bot commented Feb 2, 2025

Codecov Report

Attention: Patch coverage is 94.68085% with 5 lines in your changes missing coverage. Please review.

Project coverage is 79%. Comparing base (831870a) to head (5deb0bb).

❗ There is a different number of reports uploaded between BASE (831870a) and HEAD (5deb0bb).

HEAD has 349 uploads less than BASE

Flag                BASE (831870a)    HEAD (5deb0bb)
cpu                            105                27
python3.10                      24                 6
lightning_fabric                26                 0
pytest                          57                 0
python                          12                 3
python3.12                      10                 3
python3.12.7                    35                 9
lightning                       60                15
python3.11                      24                 6
gpu                              4                 0
pytorch2.1                      12                 6
pytorch_lightning               23                12
pytest-full                     52                27
pytorch2.2.2                     6                 3
pytorch2.3                       6                 3
pytorch2.5                       6                 3
pytorch2.6                       6                 3
pytorch2.4.1                     6                 3
pytorch2.5.1                     5                 3
pytorch2.7                       5                 3
Additional details and impacted files
@@            Coverage Diff            @@
##           master   #20545     +/-   ##
=========================================
- Coverage      87%      79%     -8%     
=========================================
  Files         268      266      -2     
  Lines       23449    23488     +39     
=========================================
- Hits        20389    18475   -1914     
- Misses       3060     5013   +1953     

Collaborator

@lantiga lantiga left a comment

Solid contribution @senarvi! I added a few comments (most are quick to address, let me know what you can do here vs follow up PR), but overall looks great.

checkpoint["state_dict"] = {
name[7:]: value for name, value in average_model_state.items() if name.startswith("module.")
}
checkpoint["averaging_state"] = {
Collaborator

Suggested change
checkpoint["averaging_state"] = {
checkpoint["averaged_state"] = {

I get that it might still be "averaging" :), but it's in fact "averaged" up to the current iteration. We could call it the "average" model if that sounds better.

Contributor Author

@lantiga the name is a bit confusing, but it means the state of the averaging process, not the average model. This includes the state variables of the AveragedModel class, excluding the module (i.e. n_averaged). The average model is saved in state_dict, so whatever we do with the checkpoint, we'll use the average model. The current model state is saved in current_model_state, so that we can continue training with the WeightAveraging callback from the previous state. If you have a less confusing name for "the averaging state variables, excluding the averaged model parameters", I can change it.


"""
if self._average_model is None:
raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
Collaborator

This is hard to understand for a user if they don't know the details of the callback.

Suggested change
raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.")
raise Exception("Trying to load a checkpoint using the WeightAveraging callback outside the `fit` stage. The WeightAveraging callback can only be used in the `fit` stage.")

I'm wondering: instead of raising we could just load the average model e.g. for predict. This will avoid forcing users to remove the callback from the Trainer.

Contributor Author

@senarvi senarvi Feb 4, 2025

@lantiga I guess I just wasn't sure in which situation this callback would be called outside fit, but yes, if the user calls Trainer.validate/test/predict(ckpt_path=...), I believe this will be called and the best thing to do would be to load the average model. The average model will be loaded if we don't do anything. Maybe just display a warning in that case.

I guess on_save_checkpoint() can also be called outside fit - if the user calls Trainer.save_checkpoint() after training. In that case we also don't have to do anything.

Contributor Author

@lantiga This is what I did. Please check if you think the messages are clear now.

assert trainer.lightning_module == model


def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None:
Collaborator

This tests that we can crash and resume, but afaict it doesn't test whether the resulting averaging is equivalent. We can harden this in a subsequent PR, but it is important to know for sure that averaging works if I stop training and resume while averages are being taken, irrespective of where I stop and resume in the lifecycle.

Contributor Author

I'll still try to improve the test.

Contributor Author

@lantiga now I test that after stopping and resuming we get the same final model. The parameters are not identical - I have to use atol=0.001 - but they are close enough that I think the difference comes from some source of randomness rather than from a bug in restoring the checkpoint. I don't know what could cause the difference, though; I pass deterministic=True to the Trainer. I'm curious if you have some ideas, or if you think that's close enough.

@lantiga
Collaborator

lantiga commented Feb 3, 2025

BTW: I think it's totally fine to merge this as is and open an issue to gather discussions about averaging buffers.

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

@senarvi
Contributor Author

senarvi commented Feb 4, 2025

The other question I have (for the future) is related to fitting both models on GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. cpu) to keep the callback usable with larger models.

There's a device argument already, and actually the default is cpu - as with StochasticWeightAveraging.
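For example, a hedged sketch (assuming the argument accepts the same values as AveragedModel's device parameter):

from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

# Keep the averaged copy in host memory (the default):
WeightAveraging(avg_fn=get_ema_avg_fn(0.9999), device="cpu")
# Or keep it on the training GPU, trading memory for update speed:
WeightAveraging(avg_fn=get_ema_avg_fn(0.9999), device="cuda:0")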

@senarvi senarvi force-pushed the generic-weight-averaging branch from c6856eb to 42d91cd on February 10, 2025 10:30
@h2o64

h2o64 commented Feb 21, 2025

Hi! Thanks for this great PR. The current implementation only leverages the avg_fn argument; should it also consider the in-place version, multi_avg_fn?

@senarvi
Contributor Author

senarvi commented Feb 21, 2025

Hi! Thanks for this great PR. The current implementation only leverages the avg_fn argument; should it also consider the in-place version, multi_avg_fn?

I think we could just pass **averaged_model_kwargs. I'll look into it over the weekend.
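If the kwargs are simply forwarded to the AveragedModel constructor, usage could look roughly like this (get_ema_multi_avg_fn is PyTorch's in-place EMA helper):

from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_multi_avg_fn

# multi_avg_fn updates all averaged parameters in a single in-place (foreach) call.
callback = WeightAveraging(multi_avg_fn=get_ema_multi_avg_fn(0.9999))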

- The user can specify when to update the average model by overriding the should_update() method
- Any keyword arguments will be passed to the AveragedModel constructor
@senarvi senarvi force-pushed the generic-weight-averaging branch from 8953a18 to 822231f on April 3, 2025 18:27
@senarvi senarvi force-pushed the generic-weight-averaging branch from cf34483 to 5deb0bb on April 4, 2025 07:22
@senarvi
Contributor Author

senarvi commented Apr 4, 2025

Thank you @catalpaaa for checking the PR! I removed the logging. I think it's not needed anymore. Originally, my biggest fear was that the weights are not transferred correctly, which could go unnoticed. And actually... I was kind of right.

The test_ema_resume test was failing in the CI pipeline. I had to set the tolerance to a suspiciously high number when comparing the model weights. So I started once again looking into where that difference comes from.

test_ema_resume trains two models, one without interruption and one with an intentional crash and recovery from a checkpoint after N epochs. I noticed that the weights that I load from the checkpoint are not the same as those the model has at the beginning of epoch N+1. It appears that Lightning has already loaded the model weights from "state_dict" when entering the on_load_checkpoint() callback. I had assumed that I could swap "current_model_state" into "state_dict" in the callback, but in fact I have to reload the model state from "current_model_state".

I'm super happy that I found this bug. Now, I think, the only test that's failing is tests_pytorch/callbacks/test_pruning.py::test_pruning_callback_ddp[True-True] in PyTorch | oldest. I can't find any error other than Bash exited with code '1'.

@catalpaaa

catalpaaa commented Apr 11, 2025

When training with TPUs, I add

callbacks.append(WeightAveraging(avg_fn=get_ema_avg_fn(0.9999)))

to my callbacks, I get the following error:

concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_ema_avg_fn.<locals>.ema_update'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/Theia/train.py", line 139, in <module>
    main()
  File "/root/Theia/train.py", line 135, in main
    trainer.fit(model, datamodule=data)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/trainer.py", line 567, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/launchers/xla.py", line 98, in launch
    process_context = xmp.spawn(
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 39, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 213, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 169, in run_multiprocess
    replica_results = list(
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/_internal/pjrt.py", line 170, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 575, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/usr/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'get_ema_avg_fn.<locals>.ema_update'

The code runs fine with a single GPU.

If I pull the decay function out of PyTorch's get_ema_avg_fn:

import torch
from torch import Tensor

# Defined at module level (not as a closure), so it can be pickled when processes are spawned.
@torch.no_grad()
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
    decay = 0.9999
    return decay * ema_param + (1 - decay) * current_param

callbacks.append(WeightAveraging(avg_fn=ema_update))

Everything runs fine. Please let me know if I'm using this wrong.

@senarvi
Contributor Author

senarvi commented Apr 12, 2025

Thanks @catalpaaa. You're using it correctly. I noticed the same thing with ddp_spawn, so I had to use a similar workaround in the unit tests.

The problem seems to be caused by get_ema_avg_fn() returning a closure. I don't know if there's anything I can do about it. If get_ema_avg_fn were a class instead of a function that returns a closure, like in the unit tests, the problem would be solved. Since the whole point was to avoid duplicating code between PyTorch and Lightning, maybe it would be best to fix this in PyTorch.
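As a workaround in the meantime, a module-level callable class works; a minimal sketch (the name EMAAverageFunction is just illustrative, not the helper from the unit tests):

import torch
from torch import Tensor

class EMAAverageFunction:
    """A picklable alternative to get_ema_avg_fn(): a top-level class instead of a closure."""

    def __init__(self, decay: float = 0.9999):
        self.decay = decay

    @torch.no_grad()
    def __call__(self, ema_param: Tensor, current_param: Tensor, num_averaged: int) -> Tensor:
        return self.decay * ema_param + (1.0 - self.decay) * current_param

# callbacks.append(WeightAveraging(avg_fn=EMAAverageFunction(0.9999)))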

@catalpaaa

I guess pickle just hates function in function :(

@amorehead
Contributor

amorehead commented Apr 23, 2025

Another respectful bump.

@amorehead
Contributor

amorehead commented Apr 24, 2025

@senarvi, while testing out your callback in a local codebase of mine, I discovered an edge case that should be simple to address. Namely, when one uses a LightningModule's configure_model hook to initialize the model weights efficiently (e.g., when the model is too large to fit into CPU memory), this callback currently tries to wrap pl_module in AveragedModel before the model weights are loaded (since configure_callbacks is called in Lightning before configure_model is). To ensure that users of Lightning always load their model's weights before trying to wrap them in AveragedModel, you can simply change the setup hook for WeightAveraging to read as follows:

@override
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
    """Called when fit, validate, test, predict, or tune begins.

    Creates an `AveragedModel` when fit begins.

    Args:
        trainer: The current `~lightning.pytorch.trainer.trainer.Trainer` instance.
        pl_module: The current `~lightning.pytorch.core.LightningModule` instance.
        stage: The `~lightning.pytorch.trainer.trainer.Trainer` state.
    """
    if stage == "fit":
        device = self._device or pl_module.device
        pl_module.configure_model()  # add this to make sure the model is wrapped correctly
        self._average_model = AveragedModel(
            model=pl_module,
            device=device,
            use_buffers=self._use_buffers,
            **self._kwargs,
        )

@amorehead
Contributor

amorehead commented Apr 24, 2025

Also, I've filled in the EMAWeightAveraging example into what may be a nice default callback configuration for users. Feel free to have a look:

import torch

from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn
from typing import Any, Optional, Union


class EMAWeightAveraging(WeightAveraging):
    """Exponential Moving Average (EMA) Weight Averaging callback."""

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = "cpu",
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__(
            device=device,
            use_buffers=use_buffers,
            **kwargs,
            avg_fn=get_ema_avg_fn(decay=decay),
        )

        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
        """Decide when to update the model weights.

        Args:
            step_idx: The current step index.
            epoch_idx: The current epoch index.
        Returns:
            bool: True if the model weights should be updated, False otherwise.
        """
        if step_idx is not None:
            # Check step-based conditions only if we have a valid step_idx
            meets_step_requirement = (
                self.update_starting_at_step is None
                or step_idx >= self.update_starting_at_step
            )
            meets_step_frequency = (
                self.update_every_n_steps > 0
                and step_idx % self.update_every_n_steps == 0
            )
            if meets_step_requirement and meets_step_frequency:
                return True

        if epoch_idx is not None:
            # Check epoch-based condition only if we specify one
            meets_epoch_requirement = (
                self.update_starting_at_epoch is not None
                and epoch_idx >= self.update_starting_at_epoch
            )
            if meets_epoch_requirement:
                return True

        return False
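For completeness, one hypothetical way to plug the sketch above into a Trainer:

from lightning.pytorch import Trainer

# Start EMA updates after 1,000 optimizer steps, then update after every step.
trainer = Trainer(callbacks=[EMAWeightAveraging(decay=0.9999, update_starting_at_step=1000)])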

@senarvi
Contributor Author

senarvi commented Apr 24, 2025

That's a very good catch @amorehead ! Thanks for testing this before it's merged.

I wasn't familiar with the configure_model hook, so I wanted to make sure that we're using it correctly. According to the documentation, it's called in a "strategy and precision aware context", so that when using a sharded strategy, the model is sharded instantly. Normally it's called in _call_configure_model() like this:

with (
    trainer.strategy.tensor_init_context(),
    trainer.strategy.model_sharded_context(),
    trainer.precision_plugin.module_init_context(),
):
    _call_lightning_module_hook(trainer, "configure_model")

Should we call _call_configure_model(trainer) instead of pl_module.configure_model()? Otherwise, I guess the model is not sharded at this point.

@amorehead
Contributor

amorehead commented Apr 24, 2025

@senarvi, you are right! I just noticed that the model will not be sharded unless Lightning's private method lightning.pytorch.trainer.call._call_configure_model is called instead of pl_module.configure_model. However, I don't think this will be enough to fully support model-parallel training strategies (such as FSDP2, which I'm currently trying to test with this callback), since each GPU rank will contain only a portion of the model weights (and there's currently only a single _average_model instance being updated).

  1. Currently, if one tries to wrap the pl_module with AveragedModel after sharding the model with _call_configure_model, they will likely run into this PyTorch issue raised by FSDP's use of copy.deepcopy on the model weights.
  2. Another approach is to follow this recent PyTorch issue and try to wrap the (full, potentially large weights) pl_module with AveragedModel before sharding the model weights (by coincidentally calling the original pl_module.configure_model method). Then, the goal would be to (identically) shard both the pl_module's model weights as well as the _average_model weights, such that calls to self._average_model.update_parameters in on_train_batch_end and on_train_epoch_end would yield updates to the sharded parameters stored on each GPU rank.

To implement approach 2 above, we would need to (1) keep the call to pl_module.configure_model as it is and (2) figure out how to shard self._average_model identically to how the model weights in pl_module are sharded. One naive approach to sharding self._average_model would be to refactor this callback to instead store it as a property of pl_module so it gets automatically sharded by Lightning's (automatic) call to lightning.pytorch.trainer.call._call_configure_model, but this could get tricky and may complicate this initial implementation of the callback (which should work well for DDP-based (data-parallel) training).

Update: On further reflection, I think the only way to make this callback compatible with model-parallel training would be to store _average_model as a property of pl_module everywhere in this callback's source code. Otherwise, we would be assuming there is only one copy of the averaged model weights, when in fact there can be many averaged (fragment) model weights. Fortunately, it looks like everywhere _average_model is referenced, a pl_module instance is available. Now, what I'm unsure of is whether this pl_module instance (when training with model parallelism) contains only the model weights belonging to the shard of an individual GPU rank (or whether this points to the full model weights somehow).

@senarvi
Contributor Author

senarvi commented Apr 24, 2025

@amorehead , if I understand correctly, each GPU rank is running in its own process that contains a portion (shard) of the model weights and an AveragedModel instance. Because of issue 1, AveragedModel is not able to copy the parameters of an FSDP2 model, but if that issue is solved, the AveragedModel of each GPU rank would contain averaged values of the corresponding shard. In theory, we should be able to construct the final averaged model by gathering the AveragedModel weights from all GPU ranks, right? This should be done before saving a checkpoint, or we'll end up saving only the shard of rank 0.

The other option is to call configure_model() without model_sharded_context, then construct the AveragedModel, and finally shard both models. You're suggesting that if we simply store the AveragedModel as a property of pl_module, it gets sharded automatically. Two questions come to my mind:

  1. Is the sharding of the averaged parameters guaranteed to be identical to the sharding of the original parameters?
  2. If you need to use FSDP, you probably want the averaged parameters to be on CPU, right? I think this would cause the averaged parameters to be stored on the same device as the original parameters.

@amorehead
Contributor

amorehead commented Apr 24, 2025

@senarvi, for the time being, it seems like the reference implementation mentioned in this GitHub issue is a standardized (PyTorch-tested) way of running EMA weight updates with FSDP. At first glance, with this reference implementation, I'm not sure if the EMA weights can be stored on any device other than a GPU (for sake of sharding). Also, this GitHub issue makes me realize that no matter what, one will have to instantiate the full set of model weights, either to initialize the (full) EMA weights or to call FSDP.summon_full_params for EMA weight updates after every optimizer step. This means that, in my understanding, EMA necessarily will incur potentially costly GPU/CPU memory usage when gathering all model weights, so very large models probably won't be runnable even with PyTorch's official implementation. Surely someone must have already solved this issue, unless no one training foundation models is using EMA 😉

Update: It looks like DeepSpeed has implemented their own version of EMA for use with their Zero Stage 3 (model-parallel) training strategy. It seems they also run the equivalent of FSDP.summon_full_params to update the EMA model weights. Huh, maybe the all-gather operation isn't as memory-intensive as I'm thinking it may be.

@senarvi
Contributor Author

senarvi commented Apr 25, 2025

@amorehead , we don't need to gather all the weights at once, right? To me it looks like DeepSpeed is gathering one parameter at a time:

for param, param_ema in zip(model.parameters(), model_ema.parameters()):
    params_to_fetch = _z3_params_to_fetch([param, param_ema]) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        # Update param_ema using param.

But that would require a change in AveragedModel.update_parameters(). This could work:

params_to_fetch = _z3_params_to_fetch(list(model.parameters()) + list(model_ema.parameters())) if zero_stage == 3 else []
should_gather_param = len(params_to_fetch) > 0
with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
    self._averaged_model.update_parameters(pl_module)

But I guess that would gather all the parameters at once.

I would be happy, at this point, to have something that works, even if it's not the most memory-efficient way. I just wouldn't want the code to look like this:

if fsdp:
    with FSDP.summon_full_params(pl_module):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif deepspeed:
    params_to_fetch = _z3_params_to_fetch(pl_module.parameters()) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif ...:

Maybe we could add some hook and let the user decide what to do?
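One possible shape for such a hook; purely illustrative, and the method name update_context plus the FSDP override are assumptions, not part of this PR:

from contextlib import nullcontext

from lightning.pytorch.callbacks import WeightAveraging

class WeightAveragingWithGather(WeightAveraging):
    """Sketch: let a subclass decide how to materialize sharded parameters around an update."""

    def update_context(self, pl_module):
        # Default: nothing to gather (single-device or DDP training).
        return nullcontext()

    def update_average(self, pl_module):
        # The callback would wrap its call to AveragedModel.update_parameters() like this.
        with self.update_context(pl_module):
            self._average_model.update_parameters(pl_module)

# A user on FSDP could then override the hook, e.g.:
#     def update_context(self, pl_module):
#         return FSDP.summon_full_params(pl_module)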

@yhao-z

yhao-z commented Apr 25, 2025

When I use this WeightAveraging callback along with the ModelCheckpoint callback, I noticed that save_last=True in the ModelCheckpoint callback fails. I don't really know why, but it just happened. Please check, thanks a lot.

@cyanic-selkie

We have been using this patch in production for a few months without issues. It might be a good idea to merge it and let any other edge cases not mentioned here resolve themselves through bug reports—just my two cents.

@amorehead
Contributor

@senarvi, I think it would be best practice to add or use a Lightning hook that accomplishes this weight gathering in a standardized way; then we can just reference that hook in this callback once it's implemented. For the time being, just as @cyanic-selkie said, I think this PR is more than good enough to merge as is, since it works well for DDP training.

@amorehead
Contributor

@yhao-z, make sure you are not using the configure_model hook with DDPStrategy. Otherwise, the WeightAveraging callback may fail to correctly wrap your original model weights and thus fail to save a last checkpoint.

Seppo Enarvi added 3 commits April 26, 2025 11:23
* Fixed a reference in a docstring.
* Removed two unit tests to avoid running out of memory in the CI pipeline.
@senarvi senarvi force-pushed the generic-weight-averaging branch from 729fd60 to 3dafb4c on April 26, 2025 09:03
@senarvi
Contributor Author

senarvi commented Apr 26, 2025

I agree; better to do this in small steps, and it's probably better that support for sharded strategies is written by someone who actually uses them. But, as @amorehead said, a standardized way to gather the weights sounds like best practice.

For now, I just wanted to make sure that if someone tries to use weight averaging with configure_model(), it won't fail silently and produce incorrect results. This is what I did:

  • setup() will call configure_model() if it's overridden, but it issues a warning. It will likely run out of memory too.
  • I added a unit test that checks that the AveragedModel is constructed correctly if configure_model() is overridden, so it should work if it doesn't run out of memory. (The CI pipeline runs out of memory if I use two GPUs, so I couldn't include that.)
  • I added a note in the documentation that weight averaging doesn't work with sharded models.

@amorehead How does that sound to you?

@yhao-z was your issue also caused by overriding configure_model()? If not, please send me a minimal example that I can use to reproduce the error. I tested with the code below and it didn't fail.

import torch
from torch import nn
from torch.optim.swa_utils import get_ema_avg_fn
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

class EMATestModel(BoringModel):
    def __init__(self):
        super().__init__()
        layers = [nn.Linear(32, 32)]
        layers += [nn.ReLU(), nn.Linear(32, 2)]
        self.layer = nn.Sequential(*layers)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        return (step_idx is not None) and (step_idx >= 10)


dataset = RandomDataset(32, 64)
dataloader = DataLoader(dataset, batch_size=2)
model = EMATestModel()
trainer = Trainer(callbacks=[EMAWeightAveraging(), ModelCheckpoint(dirpath=".", save_last=True)], max_epochs=10)
trainer.fit(model, dataloader)

@senarvi
Contributor Author

senarvi commented Apr 26, 2025

@Borda if there's nothing else, all the checks passed and this could be a good time to merge.

@yhao-z

yhao-z commented Apr 26, 2025

@senarvi sorry for the late reply; I was checking into why this happened. I apologize for reporting this incorrectly. The WeightAveraging callback runs quite well alongside the ModelCheckpoint callback, without any bugs.

The reason I encountered this issue is that I'm using LightningCLI and failed to write the config.yaml properly. The WeightAveraging callback overrode the ModelCheckpoint callback in the config.yaml, so save_last=True in the ModelCheckpoint was no longer in force.

Sorry again for the confusion. I also think this PR is more than good enough to merge as is.

@amorehead
Contributor

@senarvi, thanks for updating the PR accordingly! Yes, these changes look good to me and seem reasonable for this first version of the callback.
