Commit cb7f45f

gl3lan authored and pytorchmergebot committed
remove unnecessary sync point in AveragedModel update (pytorch#158017)

Summary: The check `bool(self.n_averaged == 0)` is a CPU/GPU synchronization point that is hit on every update, even though it exists only to determine whether the AveragedModel copy has been initialized. This diff introduces a CPU-based variable for that purpose. When loading from a checkpoint, we also make sure the variable is refreshed. After this fix, each `update_parameters` call drops from 333ms to 6ms (a 98% reduction).

Test Plan: contbuild & OSS CI

Test plan from GitHub: CI

Rollback Plan:

Differential Revision: D78074709

Pull Request resolved: pytorch#158017
Approved by: https://github.com/janeyx99
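For context, a minimal sketch (illustrative variable names, not the actual PyTorch source) of why the old check synchronized and the new one does not:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Old pattern: n_averaged was a GPU buffer. Converting the comparison
    # result to a Python bool forces a device-to-host copy, which blocks
    # the CPU until all queued GPU work has finished.
    n_averaged_buffer = torch.tensor(0, dtype=torch.long, device=device)
    if bool(n_averaged_buffer == 0):  # implicit device sync on CUDA
        pass

    # New pattern: a plain Python int, so the same check is a host-only
    # comparison that never touches the GPU.
    n_averaged = 0
    if n_averaged == 0:  # no synchronization
        pass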
1 parent 5937861 commit cb7f45f

File tree

2 files changed: +78 -13


test/optim/test_swa_utils.py

Lines changed: 38 additions & 1 deletion
@@ -76,7 +76,6 @@ def _test_averaged_model(self, net_device, swa_device, ema):
         # Check that AveragedModel is on the correct device
         self.assertTrue(p_swa.device == swa_device)
         self.assertTrue(p_avg.device == net_device)
-        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)
 
     def _run_averaged_steps(self, dnn, swa_device, ema):
         ema_decay = 0.999
@@ -150,6 +149,44 @@ def test_averaged_model_state_dict(self):
         self.assertEqual(p_swa, p_swa2)
         self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)
 
+    def test_averaged_model_backward_compatibility(self):
+        """Test that AveragedModel correctly handles old checkpoints with tensor n_averaged."""
+        dnn = torch.nn.Sequential(
+            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
+        )
+        averaged_dnn = AveragedModel(dnn)
+
+        # Update parameters a few times
+        n_updates = 5
+        for _ in range(n_updates):
+            for p in dnn.parameters():
+                p.detach().add_(torch.randn_like(p))
+            averaged_dnn.update_parameters(dnn)
+
+        # Manually create a state dict with tensor n_averaged (simulating old checkpoint)
+        state_dict = averaged_dnn.state_dict()
+        # Create an old-style tensor n_averaged
+        old_n_averaged = torch.tensor(n_updates, dtype=torch.long)
+        state_dict["n_averaged"] = old_n_averaged
+
+        # Create new model and load the old-style state dict
+        averaged_dnn2 = AveragedModel(dnn)
+        averaged_dnn2.load_state_dict(state_dict)
+
+        # Check that n_averaged was correctly loaded as a Python int
+        self.assertEqual(averaged_dnn2.n_averaged, n_updates)
+        self.assertIsInstance(averaged_dnn2.n_averaged, int)
+
+        # Verify that parameters are correctly loaded
+        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
+            self.assertEqual(p_swa, p_swa2)
+
+        # Test that we can continue to update parameters without issues
+        for p in dnn.parameters():
+            p.detach().add_(torch.randn_like(p))
+        averaged_dnn2.update_parameters(dnn)
+        self.assertEqual(averaged_dnn2.n_averaged, n_updates + 1)
+
     def test_averaged_model_default_avg_fn_picklable(self):
         dnn = torch.nn.Sequential(
             torch.nn.Conv2d(1, 5, kernel_size=3),
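From a user's perspective, the compatibility path exercised above can be sketched as follows (a hedged sketch: "_extra_state" is an internal nn.Module key, and strict=False stands in for a checkpoint written before this change):

    import torch
    from torch.optim.swa_utils import AveragedModel

    model = torch.nn.Linear(5, 10)
    averaged = AveragedModel(model)

    # Simulate an old checkpoint: n_averaged stored as a long tensor and
    # no extra-state entry, since the old AveragedModel did not define one.
    old_state = averaged.state_dict()
    old_state.pop("_extra_state", None)
    old_state["n_averaged"] = torch.tensor(3, dtype=torch.long)

    # The load_state_dict pre-hook converts the tensor to a Python int
    # and drops the stale key before the strictness checks run.
    averaged.load_state_dict(old_state, strict=False)
    assert isinstance(averaged.n_averaged, int) and averaged.n_averaged == 3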

torch/optim/swa_utils.py

Lines changed: 40 additions & 12 deletions
@@ -116,6 +116,28 @@ def swa_update(
     return swa_update
 
 
+def _load_state_dict_pre_hook(
+    module,
+    state_dict,
+    prefix,
+    local_metadata,
+    strict,
+    missing_keys,
+    unexpected_keys,
+    error_msgs,
+):
+    """Pre-hook to handle backward compatibility with tensor n_averaged."""
+    # Check if the old tensor n_averaged is present in the state dict
+    n_averaged_key = prefix + "n_averaged"
+    if n_averaged_key in state_dict:
+        # Convert tensor n_averaged to Python int for backward compatibility
+        n_averaged_tensor = state_dict[n_averaged_key]
+        if isinstance(n_averaged_tensor, Tensor):
+            module.n_averaged = int(n_averaged_tensor.item())
+        # Remove the old tensor buffer from state_dict to avoid loading it
+        del state_dict[n_averaged_key]
+
+
 class AveragedModel(Module):
     r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
@@ -215,7 +237,7 @@ class AveragedModel(Module):
     https://paperswithcode.com/method/polyak-averaging
     """
 
-    n_averaged: Tensor
+    n_averaged: int
 
     def __init__(
         self,
@@ -234,17 +256,25 @@ def __init__(
         self.module = deepcopy(model)
         if device is not None:
             self.module = self.module.to(device)
-        self.register_buffer(
-            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
-        )
+        self.n_averaged = 0
         self.avg_fn = avg_fn
         self.multi_avg_fn = multi_avg_fn
         self.use_buffers = use_buffers
+        self.register_load_state_dict_pre_hook(_load_state_dict_pre_hook)
 
     def forward(self, *args, **kwargs):
         """Forward pass."""
         return self.module(*args, **kwargs)
 
+    def get_extra_state(self) -> Any:
+        """Get extra state for serialization."""
+        return {"n_averaged": self.n_averaged}
+
+    def set_extra_state(self, state: Any) -> None:
+        """Set extra state from deserialization."""
+        if isinstance(state, dict) and "n_averaged" in state:
+            self.n_averaged = state["n_averaged"]
+
     def update_parameters(self, model: Module):
         """Update model parameters."""
         self_param = (
@@ -280,28 +310,26 @@ def update_parameters(self, model: Module):
                     self.multi_avg_fn(
                         self_params,  # type: ignore[arg-type]
                         model_params,  # type: ignore[arg-type]
-                        self.n_averaged.to(device),
+                        self.n_averaged,
                     )
                 elif (
                     device is not None
                     and device.type in _get_foreach_kernels_supported_devices()
                 ):
                     multi_avg_fn = get_swa_multi_avg_fn()
-                    multi_avg_fn(
-                        self_params, model_params, self.n_averaged.to(device)
-                    )
+                    multi_avg_fn(self_params, model_params, self.n_averaged)
                 else:
                     avg_fn = get_swa_avg_fn()
-                    n_averaged = self.n_averaged.to(device)
                     for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
-                        p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
+                        p_averaged.copy_(
+                            avg_fn(p_averaged, p_model, self.n_averaged)
+                        )
         else:
             for p_averaged, p_model in zip(  # type: ignore[assignment]
                 self_param_detached, model_param_detached
             ):
-                n_averaged = self.n_averaged.to(p_averaged.device)
                 p_averaged.detach().copy_(
-                    self.avg_fn(p_averaged.detach(), p_model, n_averaged)
+                    self.avg_fn(p_averaged.detach(), p_model, self.n_averaged)
                 )
 
         if not self.use_buffers:
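The serialization mechanics the new code relies on are worth spelling out: nn.Module persists whatever get_extra_state() returns under an "_extra_state" key in the state dict and feeds it back through set_extra_state() on load, so the int counter survives checkpointing without being a tensor buffer. A self-contained sketch (the Counter class is illustrative, not part of PyTorch):

    import torch
    from torch import nn

    class Counter(nn.Module):
        """Minimal analogue of the AveragedModel change: a Python int
        that round-trips through the state dict via the extra-state API."""

        def __init__(self):
            super().__init__()
            self.n_averaged = 0

        def get_extra_state(self):
            return {"n_averaged": self.n_averaged}

        def set_extra_state(self, state):
            if isinstance(state, dict) and "n_averaged" in state:
                self.n_averaged = state["n_averaged"]

    m = Counter()
    m.n_averaged = 7
    state = m.state_dict()      # contains {'_extra_state': {'n_averaged': 7}}
    m2 = Counter()
    m2.load_state_dict(state)
    assert m2.n_averaged == 7   # restored as a plain int; no tensor buffer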
