
Commit a541974

Revert "remove unnecessary sync point in AveragedModel update (pytorch#158017)"
This reverts commit cb7f45f. Reverted pytorch#158017 on behalf of https://github.com/wdvr: after discussion with the author, the change is expected to break checkpointing ([comment](pytorch#158017 (comment))).
1 parent a63221a · commit a541974
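For context on the checkpointing concern behind this revert: with n_averaged registered as a buffer (the behavior restored here), the counter is serialized inside the model's state_dict as a long tensor, which is what existing checkpoints already contain. The snippet below is a minimal sketch against the public torch.optim.swa_utils API, not part of the commit itself:

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)
ema = AveragedModel(model)
ema.update_parameters(model)  # first update copies params and bumps the counter

sd = ema.state_dict()
# With the buffer-based counter, n_averaged is saved as a long tensor,
# matching what older checkpoints store:
assert "n_averaged" in sd
assert sd["n_averaged"].item() == 1

# A fresh wrapper loads such a checkpoint directly, with no conversion hook.
ema2 = AveragedModel(model)
ema2.load_state_dict(sd)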

2 files changed: +13 -78 lines changed

test/optim/test_swa_utils.py
Lines changed: 1 addition & 38 deletions

@@ -76,6 +76,7 @@ def _test_averaged_model(self, net_device, swa_device, ema):
         # Check that AveragedModel is on the correct device
         self.assertTrue(p_swa.device == swa_device)
         self.assertTrue(p_avg.device == net_device)
+        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

     def _run_averaged_steps(self, dnn, swa_device, ema):
         ema_decay = 0.999
@@ -149,44 +150,6 @@ def test_averaged_model_state_dict(self):
             self.assertEqual(p_swa, p_swa2)
         self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

-    def test_averaged_model_backward_compatibility(self):
-        """Test that AveragedModel correctly handles old checkpoints with tensor n_averaged."""
-        dnn = torch.nn.Sequential(
-            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
-        )
-        averaged_dnn = AveragedModel(dnn)
-
-        # Update parameters a few times
-        n_updates = 5
-        for _ in range(n_updates):
-            for p in dnn.parameters():
-                p.detach().add_(torch.randn_like(p))
-            averaged_dnn.update_parameters(dnn)
-
-        # Manually create a state dict with tensor n_averaged (simulating old checkpoint)
-        state_dict = averaged_dnn.state_dict()
-        # Create an old-style tensor n_averaged
-        old_n_averaged = torch.tensor(n_updates, dtype=torch.long)
-        state_dict["n_averaged"] = old_n_averaged
-
-        # Create new model and load the old-style state dict
-        averaged_dnn2 = AveragedModel(dnn)
-        averaged_dnn2.load_state_dict(state_dict)
-
-        # Check that n_averaged was correctly loaded as a Python int
-        self.assertEqual(averaged_dnn2.n_averaged, n_updates)
-        self.assertIsInstance(averaged_dnn2.n_averaged, int)
-
-        # Verify that parameters are correctly loaded
-        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
-            self.assertEqual(p_swa, p_swa2)
-
-        # Test that we can continue to update parameters without issues
-        for p in dnn.parameters():
-            p.detach().add_(torch.randn_like(p))
-        averaged_dnn2.update_parameters(dnn)
-        self.assertEqual(averaged_dnn2.n_averaged, n_updates + 1)
-
     def test_averaged_model_default_avg_fn_picklable(self):
         dnn = torch.nn.Sequential(
             torch.nn.Conv2d(1, 5, kernel_size=3),
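The restored assertion above checks that the n_averaged buffer lands on the SWA device along with the averaged parameters. A minimal sketch of that behavior with a toy model (not from the test file; falls back to CPU when CUDA is absent):

import torch
from torch.optim.swa_utils import AveragedModel

swa_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dnn = torch.nn.Linear(4, 2)

# Passing device= places the averaged copy and the counter buffer there,
# since the buffer is created as torch.tensor(0, dtype=torch.long, device=device).
averaged_dnn = AveragedModel(dnn, device=swa_device)
assert averaged_dnn.n_averaged.device.type == swa_device.type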
torch/optim/swa_utils.py
Lines changed: 12 additions & 40 deletions

@@ -116,28 +116,6 @@ def swa_update(
     return swa_update


-def _load_state_dict_pre_hook(
-    module,
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    """Pre-hook to handle backward compatibility with tensor n_averaged."""
-    # Check if the old tensor n_averaged is present in the state dict
-    n_averaged_key = prefix + "n_averaged"
-    if n_averaged_key in state_dict:
-        # Convert tensor n_averaged to Python int for backward compatibility
-        n_averaged_tensor = state_dict[n_averaged_key]
-        if isinstance(n_averaged_tensor, Tensor):
-            module.n_averaged = int(n_averaged_tensor.item())
-            # Remove the old tensor buffer from state_dict to avoid loading it
-            del state_dict[n_averaged_key]
-
-
 class AveragedModel(Module):
     r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

@@ -237,7 +215,7 @@ class AveragedModel(Module):
     https://paperswithcode.com/method/polyak-averaging
     """

-    n_averaged: int
+    n_averaged: Tensor

     def __init__(
         self,
@@ -256,25 +234,17 @@ def __init__(
         self.module = deepcopy(model)
         if device is not None:
             self.module = self.module.to(device)
-        self.n_averaged = 0
+        self.register_buffer(
+            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
+        )
         self.avg_fn = avg_fn
         self.multi_avg_fn = multi_avg_fn
         self.use_buffers = use_buffers
-        self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook)

     def forward(self, *args, **kwargs):
         """Forward pass."""
         return self.module(*args, **kwargs)

-    def get_extra_state(self) -> Any:
-        """Get extra state for serialization."""
-        return {"n_averaged": self.n_averaged}
-
-    def set_extra_state(self, state: Any) -> None:
-        """Set extra state from deserialization."""
-        if isinstance(state, dict) and "n_averaged" in state:
-            self.n_averaged = state["n_averaged"]
-
     def update_parameters(self, model: Module):
         """Update model parameters."""
         self_param = (
@@ -310,26 +280,28 @@ def update_parameters(self, model: Module):
                         self.multi_avg_fn(
                             self_params,  # type: ignore[arg-type]
                             model_params,  # type: ignore[arg-type]
-                            self.n_averaged,
+                            self.n_averaged.to(device),
                         )
                     elif (
                         device is not None
                         and device.type in _get_foreach_kernels_supported_devices()
                     ):
                         multi_avg_fn = get_swa_multi_avg_fn()
-                        multi_avg_fn(self_params, model_params, self.n_averaged)
+                        multi_avg_fn(
+                            self_params, model_params, self.n_averaged.to(device)
+                        )
                     else:
                         avg_fn = get_swa_avg_fn()
+                        n_averaged = self.n_averaged.to(device)
                         for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
-                            p_averaged.copy_(
-                                avg_fn(p_averaged, p_model, self.n_averaged)
-                            )
+                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
             else:
                 for p_averaged, p_model in zip(  # type: ignore[assignment]
                     self_param_detached, model_param_detached
                 ):
+                    n_averaged = self.n_averaged.to(p_averaged.device)
                     p_averaged.detach().copy_(
-                        self.avg_fn(p_averaged.detach(), p_model, self.n_averaged)
+                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                     )

         if not self.use_buffers:
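For reference, the trade-off the reverted PR was targeting: with n_averaged kept as a tensor buffer, every update moves the counter to the parameter device before the averaging math, and pytorch#158017 had removed that transfer by switching to a Python int. Below is a sketch of the equal-weight update that get_swa_avg_fn() computes, with the counter handled buffer-style; the toy tensors and the swa_step helper are illustrative, not library internals:

import torch

# Running equal-weight average, as in SWA:
#   avg_{k+1} = avg_k + (x_{k+1} - avg_k) / (k + 1)
n_averaged = torch.tensor(0, dtype=torch.long)  # buffer-style counter, lives on CPU

def swa_step(p_averaged: torch.Tensor, p_model: torch.Tensor) -> torch.Tensor:
    # The .to() below is the per-update transfer the reverted PR avoided:
    # when p_averaged lives on CUDA, the CPU counter must be copied over first.
    n = n_averaged.to(p_averaged.device)
    return p_averaged + (p_model - p_averaged) / (n + 1)

avg = torch.zeros(3)
for x in (torch.ones(3), 3 * torch.ones(3)):
    avg = swa_step(avg, x)
    n_averaged += 1
print(avg)  # tensor([2., 2., 2.]), the mean of 1 and 3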