Generic weight averaging callback that supports EMA #20545
Conversation
Hey @senarvi, this looks great! I saw you already added support for saving and resuming. There are many scenarios there (save every n steps, time-based, every epoch, etc.), so let's make sure we cover them all (for inspiration, we added quite a few tests in #20379).
No, I think it's better to have one callback with configurable averaging flags; it's more Lightning-esque.
I think this is ok, but I have a doubt about forcing […]. Wdyt about this? I don't necessarily want to make the implementation more complex, so this is just for discussion.
It would be nice to make it configurable, and probably users will want to get to some minimum and then start averaging. The criteria to do so may be very bespoke, so maybe allowing the user to implement a custom hook to decide whether to start averaging or whether to average at a given step would be super handy. Otherwise I'm expecting users will train for some time, save a checkpoint, then reload with this callback added to the trainer and start averaging. Which is totally fine but it requires you to stop and resume. Regarding removing the StochasticWeightAveraging callback, I don't necessarily see that happening. We have a pretty strong commitment to backward compatibility at this point, so keeping that in with a notice to just use this one will not hurt.
That's a good point. I don't know what would be a good solution.
That's an interesting idea. We could have the user pass a function […]. It seems that AveragedModel will copy the current model parameters when called the first time, and update the average on subsequent calls. This means that the first average is computed when […]. I checked how StochasticWeightAveraging does this and I think it doesn't work correctly: it only ever updates the average model parameters in on_train_epoch_start(), so the average is not updated after the last epoch. Just shows why I'd like to keep the logic as simple as possible.
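For reference, a minimal standalone sketch (not from the PR) of the `AveragedModel` behaviour described above: the first call to `update_parameters()` copies the weights, and subsequent calls update the running average.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

# Toy model so the averaging behaviour is easy to inspect.
model = nn.Linear(2, 2)
average_model = AveragedModel(model)

# First call: the current parameters are simply copied into the average model.
average_model.update_parameters(model)
print(average_model.n_averaged)  # tensor(1)

# Change the weights and update again: now a running average is actually computed.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)
average_model.update_parameters(model)
print(average_model.n_averaged)  # tensor(2)
```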
Hi, I have a couple of questions.
During training (stage=fit), the actual LightningModule is what we update using the optimizer (I call it the current model), and an AveragedModel is maintained in the background (I call it the average model). I assume that validation is only called during training. Before and after validation we swap the current model and the average model, so the average model will be validated. When saving a checkpoint, we save the average model parameters in the state_dict. So if you later load the checkpoint without the WeightAveraging callback and run a test or export to ONNX, you will be using the average parameters. When training ends, we copy the average model parameters to the current model. So if you run a test or export to ONNX after training, you will be using the average parameters. That's the idea at least. I'm not confident that I have thought about every possible corner case. It would be great if you could test that it works in your case.
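As an illustration of the swapping idea only (this is not the callback's actual code), a minimal sketch with a hypothetical `swap_parameters` helper:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel


def swap_parameters(current_model: nn.Module, average_model: AveragedModel) -> None:
    """Exchange the parameter values of the current model and the averaged copy in place."""
    with torch.no_grad():
        for p_current, p_average in zip(current_model.parameters(), average_model.module.parameters()):
            tmp = p_current.detach().clone()
            p_current.copy_(p_average)
            p_average.copy_(tmp)


# Hypothetical usage around validation: swap in the averaged weights, validate,
# then swap back so training continues with the current weights.
# swap_parameters(pl_module, average_model)
# ... run validation ...
# swap_parameters(pl_module, average_model)
```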
@senarvi Ah! Thanks for the clarification, I should've checked the code out more carefully. I tried your branch out on a quantization aware training enabled model with ONNX export at the end and everything is working beautifully! I hope this gets merged quickly.
Force-pushed from efc77dc to 0010492.
The user can now provide either the `update_on_step` or the `update_on_epoch` function. For example: `update_on_step = lambda x: x > 100` or `update_on_epoch = lambda x: x in (3, 5, 7)`. Using […], I tested EMA in an actual learning task and it gave an improvement, so I'm starting to be more confident that this works. I think the biggest question that is still left is whether it's a problem that we force […].
@tchaton I think you contributed the […].
* A callback that updates a torch.optim.swa_utils.AveragedModel after specific steps or epochs.
* The user can provide a callback that defines after which steps or epochs the average model is updated.
Force-pushed from 5f34205 to c8d50bd.
Is there anything blocking this from being merged?
I marked this ready for review. There were no comments on whether it's a problem that we force […].
Codecov Report

Attention: Patch coverage is […].

Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #20545     +/-   ##
=========================================
- Coverage      87%      79%       -8%
=========================================
  Files         268      266        -2
  Lines       23449    23488       +39
=========================================
- Hits        20389    18475     -1914
- Misses       3060     5013     +1953
Solid contribution @senarvi! I added a few comments (most are quick to address, let me know what you can do here vs follow up PR), but overall looks great.
checkpoint["state_dict"] = { | ||
name[7:]: value for name, value in average_model_state.items() if name.startswith("module.") | ||
} | ||
checkpoint["averaging_state"] = { |
checkpoint["averaging_state"] = { | |
checkpoint["averaged_state"] = { |
I get that it might still be "averaging" :), but it's in fact "averaged" up to the current iteration. We could call it the "average" model if that sounds better.
@lantiga the name is a bit confusing, but it means the state of the averaging process, not the average model. This includes the state variables of the `AveragedModel` class, excluding the module (i.e. `n_averaged`). The average model is saved in `state_dict`, so whatever we'll do with the checkpoint, we'll use the average model. The current model state is saved in `current_model_state`, so that we can continue training with the WeightAveraging callback from the previous state. If you have a less confusing name for the "averaging state variables, excluding the averaged model parameters", I can change it.
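For readers following along, a rough sketch of the checkpoint layout described here (key names taken from this thread; the actual implementation may differ):

```python
from typing import Any, Dict

from torch import nn
from torch.optim.swa_utils import AveragedModel


def build_checkpoint(pl_module: nn.Module, average_model: AveragedModel) -> Dict[str, Any]:
    """Sketch of the checkpoint contents being discussed above."""
    average_model_state = average_model.state_dict()  # model weights are prefixed with "module."
    return {
        # Averaged weights with the "module." prefix stripped, so the checkpoint
        # loads like a regular LightningModule checkpoint.
        "state_dict": {
            name[len("module."):]: value
            for name, value in average_model_state.items()
            if name.startswith("module.")
        },
        # State of the averaging process itself (e.g. n_averaged), excluding the module.
        "averaging_state": {
            name: value for name, value in average_model_state.items() if not name.startswith("module.")
        },
        # The (non-averaged) weights currently being optimized, so that training can resume.
        "current_model_state": pl_module.state_dict(),
    }
```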
""" | ||
if self._average_model is None: | ||
raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.") |
This is hard to understand for a user if they don't know the details of the callback.
raise Exception("Trying to load a checkpoint, but no average model (outside fit). Don't know what to do.") | |
raise Exception("Trying to load a checkpoint using the WeightAveraging callback outside the `fit` stage. The WeightAveraging callback can only be used in the `fit` stage.") |
I'm wondering: instead of raising, we could just load the average model, e.g. for predict. This will avoid forcing users to remove the callback from the Trainer.
@lantiga I guess I just wasn't sure in which situation this callback would be called outside fit, but yes, if the user calls `Trainer.validate/test/predict(ckpt_path=...)`, I believe this will be called and the best thing to do would be to load the average model. The average model will be loaded if we don't do anything, so maybe we should just display a warning in that case.
I guess `on_save_checkpoint()` can also be called outside fit, if the user calls `Trainer.save_checkpoint()` after training. In that case we also don't have to do anything.
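A minimal sketch of the behaviour settled on here, i.e. warn and fall back to the averaged weights instead of raising (hypothetical helper, not the callback's real hook):

```python
import warnings
from typing import Any, Dict, Optional

from torch.optim.swa_utils import AveragedModel


def on_load_checkpoint_sketch(checkpoint: Dict[str, Any], average_model: Optional[AveragedModel]) -> None:
    """Hypothetical stand-in for the callback's checkpoint-loading hook."""
    if average_model is None:
        # Outside the fit stage there is no averaging state to restore. The averaged
        # weights already live in checkpoint["state_dict"], so warn and do nothing.
        warnings.warn(
            "WeightAveraging is only used during fit; the averaged weights stored in "
            "the checkpoint will be loaded as-is."
        )
        return
    # During fit, the averaging state and the current (non-averaged) model weights
    # saved in the checkpoint would be restored here so that training resumes correctly.
    ...
```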
@lantiga This is what I did. Please check if you think the messages are clear now.
```python
assert trainer.lightning_module == model

# ...

def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None:
```
This tests that we can crash and resume, but afaict it doesn't test whether the resulting averaging is equivalent. We can harden this in a subsequent PR, but it is important to know for sure that averaging works if I stop training and resume while averages are being taken, irrespective of where I stop and resume in the lifecycle.
I'll still try to improve the test.
@lantiga now I test that after stopping and resuming we get the same final model. The parameters are not identical (I have to use `atol=0.001`), but they are close enough that I think the difference comes from some random change rather than a bug in restoring the checkpoint. I don't know what could cause the difference, though; I pass `deterministic=True` to the Trainer. I'm curious if you have some ideas, or if you think that's close enough.
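For illustration, the kind of comparison being described could look like this (a sketch with hypothetical names, not the actual test):

```python
import torch
from torch import nn


def assert_models_close(resumed_model: nn.Module, reference_model: nn.Module) -> None:
    """Check that a crash-and-resume run ends up (almost) identical to an uninterrupted run."""
    resumed_state = resumed_model.state_dict()
    reference_state = reference_model.state_dict()
    assert resumed_state.keys() == reference_state.keys()
    for name in resumed_state:
        # A small absolute tolerance absorbs the tiny nondeterministic differences
        # mentioned above, even when deterministic=True is passed to the Trainer.
        torch.testing.assert_close(resumed_state[name], reference_state[name], atol=0.001, rtol=0.0)
```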
BTW: I think it's totally fine to merge this as is and open an issue to gather discussions about averaging buffers. The other question I have (for the future) is related to fitting both models on the GPU. It may make sense to give the ability to keep the AveragedModel on a different device (e.g. the CPU).
There's a `device` parameter for that.
Force-pushed from c6856eb to 42d91cd.
Hi! Thanks for this great PR. The current implementation only leverages […].
I think we could just pass […].
- The user can specify when to update the average model by overriding the should_update() method
- Any keyword arguments will be passed to the AveragedModel constructor
Force-pushed from 8953a18 to 822231f.
Force-pushed from cf34483 to 5deb0bb.
Thank you @catalpaaa for checking the PR! I removed the logging; I think it's not needed anymore. Originally, my biggest fear was that the weights are not transferred correctly, which could go unnoticed. And actually... I was kind of right. The […]
I'm super happy that I found this bug. Now, I think, the only test that's failing is […].
When training with TPUs, I add

```python
callbacks.append(WeightAveraging(avg_fn=get_ema_avg_fn(0.9999)))
```

to my callbacks, and I get the following errors: […]

The code runs fine with a single GPU. If I pull the decay function out of PyTorch's `get_ema_avg_fn()`:

```python
@torch.no_grad()
def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
    decay = 0.9999
    return decay * ema_param + (1 - decay) * current_param


callbacks.append(WeightAveraging(avg_fn=ema_update))
```

everything runs fine. Please let me know if I'm using this wrong.
Thanks @catalpaaa. You're using it correctly. I noticed the same thing with ddp_spawn, so I had to use a similar workaround in the unit tests. The problem seems to be caused by the fact that `get_ema_avg_fn()` defines the averaging function inside another function, which cannot be pickled.
I guess pickle just hates function in function :(
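One possible workaround (a sketch, not part of this PR) is to replace the nested function returned by `get_ema_avg_fn()` with a module-level callable object, which pickle can handle:

```python
import torch
from torch import Tensor


class EMAUpdate:
    """A picklable alternative to the closure returned by get_ema_avg_fn()."""

    def __init__(self, decay: float = 0.9999) -> None:
        self.decay = decay

    @torch.no_grad()
    def __call__(self, ema_param: Tensor, current_param: Tensor, num_averaged: Tensor) -> Tensor:
        return self.decay * ema_param + (1.0 - self.decay) * current_param


# Instances can be pickled because the class is defined at module level:
# callbacks.append(WeightAveraging(avg_fn=EMAUpdate(0.9999)))
```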
Another respectful bump.
@senarvi, while testing out your callback in a local codebase of mine, I discovered an edge case that should be simple to address. Namely, when one is using a […], the `AveragedModel` gets created before the model is wrapped correctly. Calling `configure_model()` in `setup()` fixes it:

```python
@override
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
    """Called when fit, validate, test, predict, or tune begins.

    Creates an `AveragedModel` when fit begins.

    Args:
        trainer: The current `~lightning.pytorch.trainer.trainer.Trainer` instance.
        pl_module: The current `~lightning.pytorch.core.LightningModule` instance.
        stage: The `~lightning.pytorch.trainer.trainer.Trainer` state.
    """
    if stage == "fit":
        device = self._device or pl_module.device
        pl_module.configure_model()  # add this to make sure the model is wrapped correctly
        self._average_model = AveragedModel(
            model=pl_module,
            device=device,
            use_buffers=self._use_buffers,
            **self._kwargs,
        )
```
Also, I've filled in the following `EMAWeightAveraging` subclass:

```python
from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import get_ema_avg_fn

from lightning.pytorch.callbacks import WeightAveraging


class EMAWeightAveraging(WeightAveraging):
    """Exponential Moving Average (EMA) Weight Averaging callback."""

    def __init__(
        self,
        device: Optional[Union[torch.device, str, int]] = "cpu",
        use_buffers: bool = True,
        decay: float = 0.999,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        update_starting_at_epoch: Optional[int] = None,
        **kwargs: Any,
    ):
        super().__init__(
            device=device,
            use_buffers=use_buffers,
            **kwargs,
            avg_fn=get_ema_avg_fn(decay=decay),
        )
        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step
        self.update_starting_at_epoch = update_starting_at_epoch

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None):
        """Decide when to update the model weights.

        Args:
            step_idx: The current step index.
            epoch_idx: The current epoch index.

        Returns:
            bool: True if the model weights should be updated, False otherwise.
        """
        if step_idx is not None:
            # Check step-based conditions only if we have a valid step_idx
            meets_step_requirement = (
                self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
            )
            meets_step_frequency = (
                self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
            )
            if meets_step_requirement and meets_step_frequency:
                return True
        if epoch_idx is not None:
            # Check epoch-based condition only if we specify one
            meets_epoch_requirement = (
                self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
            )
            if meets_epoch_requirement:
                return True
        return False
```
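A short usage sketch for the subclass above (argument names come from that snippet; the exact values are only an example):

```python
from lightning.pytorch import Trainer

# Keep an EMA of the weights on the CPU, updating on every step once training
# has progressed past step 1000.
ema_callback = EMAWeightAveraging(device="cpu", decay=0.9999, update_starting_at_step=1000)
trainer = Trainer(max_epochs=10, callbacks=[ema_callback])
# trainer.fit(model, train_dataloader)
```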
That's a very good catch @amorehead! Thanks for testing this before it's merged. I wasn't familiar with the […]. It looks like Lightning calls the hook like this:

```python
with (
    trainer.strategy.tensor_init_context(),
    trainer.strategy.model_sharded_context(),
    trainer.precision_plugin.module_init_context(),
):
    _call_lightning_module_hook(trainer, "configure_model")
```

Should we call […]?
@senarvi, you are right! I just noticed that the model will not be sharded unless Lightning's private method […] is used.

To implement approach 2 above, we would need to (1) keep the call to […].

Update: On further reflection, I think the only way to make this callback compatible with model-parallel training would be to store […].
@amorehead, if I understand correctly, each GPU rank is running in its own process that contains a portion (shard) of the model weights and an […]. The other option is to call […].
@senarvi, for the time being, it seems like the reference implementation mentioned in this GitHub issue is a standardized (PyTorch-tested) way of running EMA weight updates with FSDP. At first glance, with this reference implementation, I'm not sure if the EMA weights can be stored on any device other than a GPU (for the sake of sharding). Also, this GitHub issue makes me realize that no matter what, one will have to instantiate the full set of model weights, either to initialize the (full) EMA weights or to call […].

Update: It looks like DeepSpeed has implemented their own version of EMA for use with their ZeRO Stage 3 (model-parallel) training strategy. It seems they also run the equivalent of […].
@amorehead, we don't need to gather all the weights at once, right? To me it looks like DeepSpeed is gathering one parameter at a time:

```python
for param, param_ema in zip(model.parameters(), model_ema.parameters()):
    params_to_fetch = _z3_params_to_fetch([param, param_ema]) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        # Update param_ema using param.
```

But that would require a change in […]. The alternative would be something like this:

```python
params_to_fetch = _z3_params_to_fetch(list(model.parameters()) + list(model_ema.parameters())) if zero_stage == 3 else []
should_gather_param = len(params_to_fetch) > 0
with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
    self._averaged_model.update_parameters(pl_module)
```

But I guess that would gather all the parameters at once. I would be happy, at this point, to have something that works, even if it's not the most memory-efficient way. I just wouldn't want the code to look like this:

```python
if fsdp:
    with FSDP.summon_full_params(pl_module):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif deepspeed:
    params_to_fetch = _z3_params_to_fetch(pl_module.parameters()) if zero_stage == 3 else []
    should_gather_param = len(params_to_fetch) > 0
    with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=should_gather_param):
        pl_module._averaged_model.update_parameters(pl_module.current_model)
elif ...:
```

Maybe we could add some hook and let the user decide what to do?
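One possible shape for such a hook, purely as a sketch (the method name and class here are hypothetical, not existing Lightning API):

```python
from contextlib import contextmanager
from typing import Iterator

import lightning.pytorch as pl
from torch.optim.swa_utils import AveragedModel


class GatherAwareWeightAveraging:
    # Hypothetical sketch; in practice this would extend the WeightAveraging callback.

    @contextmanager
    def gather_parameters(self, pl_module: "pl.LightningModule") -> Iterator[None]:
        """Override to make the full parameters visible while the average is updated,
        e.g. with FSDP.summon_full_params() or deepspeed.zero.GatheredParameters()."""
        # Default: no gathering needed (e.g. plain DDP).
        yield

    def _update_average(self, pl_module: "pl.LightningModule", average_model: AveragedModel) -> None:
        with self.gather_parameters(pl_module):
            average_model.update_parameters(pl_module)
```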
When I use this WeightAveraging callback along with the ModelCheckpoint callback, I noticed that […].
We have been using this patch in production for a few months without issues. It might be a good idea to merge it and let any other edge cases not mentioned here resolve themselves through bug reports. Just my two cents.
@senarvi, I think it would be best practice to add or use a Lightning hook that accomplishes this weight gathering in a standardized way; then we can just reference that hook in this callback once it's implemented. For the time being, just like @cyanic-selkie said, I think this PR is more than good enough as is to merge, since it works well for DDP training.
@yhao-z, make sure you are not using the […].
* Fixed a reference in a docstring.
* Removed two unit tests to avoid running out of memory in the CI pipeline.
Force-pushed from 729fd60 to 3dafb4c.
I agree, better to do this in small steps, and it's probably better that the support for sharded strategies is written by someone who actually uses them. But something like what @amorehead said, a standardized way to gather the weights, sounds like best practice. For now, I just wanted to make sure that if someone tries to use weight averaging with […]
@amorehead How does that sound to you?

@yhao-z was your issue also caused by overriding […]? This is the script that I tried:

```python
import torch
from torch import nn
from torch.optim.swa_utils import get_ema_avg_fn
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class EMATestModel(BoringModel):
    def __init__(self):
        super().__init__()
        layers = [nn.Linear(32, 32)]
        layers += [nn.ReLU(), nn.Linear(32, 2)]
        self.layer = nn.Sequential(*layers)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        return (step_idx is not None) and (step_idx >= 10)


dataset = RandomDataset(32, 64)
dataloader = DataLoader(dataset, batch_size=2)
model = EMATestModel()
trainer = Trainer(callbacks=[EMAWeightAveraging(), ModelCheckpoint(dirpath=".", save_last=True)], max_epochs=10)
trainer.fit(model, dataloader)
```
@Borda if there's nothing else, all the checks passed and this could be a good time to merge.
@senarvi sorry for the late reply, I was checking the reason why this happened. I apologize for reporting this wrongly: the WeightAveraging callback runs quite well alongside the ModelCheckpoint callback, without any bug.

Sorry again for the confusion, and I also think this PR is more than good enough as is to merge.
@senarvi, thanks for updating the PR accordingly! Yes, these changes look good to me and seem reasonable for this first version of the callback.
A callback that updates an AveragedModel after every training step
What does this PR do?
This is similar to the existing StochasticWeightAveraging callback, but wraps the AveragedModel class from PyTorch. Reduced code duplication means easier maintenance. Also, any averaging function can be used. By default, the callback does averaging on every step, but this can be customized by overriding the should_update(step_idx, epoch_idx) method.

Fixes #10914
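For example, a sketch of customizing the schedule by overriding `should_update()` (assuming the callback is importable as `lightning.pytorch.callbacks.WeightAveraging`, as used elsewhere in this thread):

```python
from lightning.pytorch.callbacks import WeightAveraging


class EveryTenthStepAveraging(WeightAveraging):
    def should_update(self, step_idx=None, epoch_idx=None):
        # Update the AveragedModel only on every tenth training step.
        return step_idx is not None and step_idx % 10 == 0
```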
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
📚 Documentation preview 📚: https://pytorch-lightning--20545.org.readthedocs.build/en/20545/