Commit 42d91cd

Author: Seppo Enarvi
Test that stopping and resuming won't make a difference in the final model
Parent: 51b9a06

File tree

2 files changed: +84 −53 lines


src/lightning/pytorch/callbacks/weight_averaging.py

Lines changed: 19 additions & 2 deletions
@@ -23,6 +23,7 @@
 import torch
 from torch import Tensor
 from torch.optim.swa_utils import AveragedModel
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks.callback import Callback
@@ -50,10 +51,15 @@ class WeightAveraging(Callback):
 
     def __init__(
         self,
-        device: Optional[Union[torch.device, int]] = torch.device("cpu"),
+        device: Optional[Union[torch.device, str, int]] = "cpu",
         avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
     ):
-        self._device = device
+        # The default value is a string so that jsonargparse knows how to serialize it.
+        if isinstance(device, str):
+            self._device: Optional[Union[torch.device, int]] = torch.device(device)
+        else:
+            self._device = device
+
         self._avg_fn = avg_fn
         self._average_model: Optional[AveragedModel] = None
 
@@ -83,6 +89,7 @@ def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int]
         """
         return step_idx is not None
 
+    @override
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
         """Called when fit, validate, test, predict, or tune begins.
 
@@ -98,6 +105,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
         device = self._device or pl_module.device
         self._average_model = AveragedModel(model=pl_module, device=device, avg_fn=self._avg_fn, use_buffers=True)
 
+    @override
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
@@ -121,6 +129,7 @@ def on_train_batch_end(
             self._average_model.update_parameters(pl_module)
             self._latest_update_step = trainer.global_step
 
+    @override
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a training epoch ends.
 
@@ -136,6 +145,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
             self._average_model.update_parameters(pl_module)
             self._latest_update_epoch = trainer.current_epoch
 
+    @override
     def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when training ends.
 
@@ -147,8 +157,10 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
 
         """
         assert self._average_model is not None
+        rank_zero_info("Loading the average model parameters to the final model.")
         self._copy_average_to_current(pl_module)
 
+    @override
     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a validation epoch begins.
 
@@ -163,6 +175,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
            rank_zero_info("Loading the average model parameters for validation.")
            self._swap_models(pl_module)
 
+    @override
     def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when a validation epoch ends.
 
@@ -177,6 +190,7 @@ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.Lightnin
            rank_zero_info("Recovering the current model parameters after validation.")
            self._swap_models(pl_module)
 
+    @override
     def state_dict(self) -> dict[str, Any]:
         """Called when saving a checkpoint.
 
@@ -188,6 +202,7 @@ def state_dict(self) -> dict[str, Any]:
         """
         return {"latest_update_step": self._latest_update_step}
 
+    @override
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """Called when loading a checkpoint.
 
@@ -199,6 +214,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         """
         self._latest_update_step = state_dict["latest_update_step"]
 
+    @override
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
     ) -> None:
@@ -231,6 +247,7 @@ def on_save_checkpoint(
            name: value for name, value in average_model_state.items() if not name.startswith("module.")
        }
 
+    @override
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: dict[str, Any]
     ) -> None:
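
Two things change throughout this file: the `device` argument now also accepts a plain string, which `__init__` normalizes to `torch.device` so that jsonargparse can serialize the default, and every hook is marked with `@override` from `typing_extensions` so type checkers can flag signatures that drift from the `Callback` base class. A minimal usage sketch (not part of the commit; it reuses the Boring* demo classes that the tests below import):

```python
from torch.optim.swa_utils import get_swa_avg_fn
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

# The device can now be given as a string; __init__ converts it to torch.device.
callback = WeightAveraging(device="cpu", avg_fn=get_swa_avg_fn())
trainer = Trainer(max_epochs=2, callbacks=[callback], logger=False)
trainer.fit(BoringModel(), DataLoader(RandomDataset(32, 32), batch_size=4))
```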

tests/tests_pytorch/callbacks/test_weight_averaging.py

Lines changed: 65 additions & 51 deletions
@@ -12,43 +12,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from typing import Any, Optional
 
 import pytest
 import torch
 from torch import Tensor, nn
 from torch.optim.swa_utils import get_swa_avg_fn
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import WeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from tests_pytorch.helpers.runif import RunIf
 
 
-class WeightAveragingTestModel(BoringModel):
-    def __init__(
-        self, batch_norm: bool = True, iterable_dataset: bool = False, crash_on_epoch: Optional[int] = None
-    ) -> None:
+class TestModel(BoringModel):
+    def __init__(self, batch_norm: bool = True) -> None:
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batch_norm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
-        self.iterable_dataset = iterable_dataset
-        self.crash_on_epoch = crash_on_epoch
+        self.crash_on_epoch = None
 
     def training_step(self, batch: Tensor, batch_idx: int) -> None:
         if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
-            raise Exception("CRASH TEST")
+            raise Exception("CRASH")
         return super().training_step(batch, batch_idx)
 
-    def train_dataloader(self) -> None:
-        dataset_class = RandomIterableDataset if self.iterable_dataset else RandomDataset
-        return DataLoader(dataset_class(32, 32), batch_size=4)
-
     def configure_optimizers(self) -> None:
         return torch.optim.SGD(self.layer.parameters(), lr=0.1)
 
@@ -194,95 +188,115 @@ def setup(self, trainer, pl_module, stage) -> None:
 @pytest.mark.parametrize("batch_norm", [True, False])
 @pytest.mark.parametrize("iterable_dataset", [True, False])
 def test_ema(tmp_path, batch_norm: bool, iterable_dataset: bool):
-    _train(tmp_path, EMATestCallback(), batch_norm=batch_norm, iterable_dataset=iterable_dataset)
+    model = TestModel(batch_norm=batch_norm)
+    dataset = RandomIterableDataset(32, 32) if iterable_dataset else RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback())
 
 
 @pytest.mark.parametrize(
     "accelerator", [pytest.param("gpu", marks=RunIf(min_cuda_gpus=1)), pytest.param("mps", marks=RunIf(mps=True))]
 )
 def test_ema_accelerator(tmp_path, accelerator):
-    _train(tmp_path, EMATestCallback(), accelerator=accelerator, devices=1)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(), accelerator=accelerator, devices=1)
 
 
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_ema_ddp(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp", accelerator="gpu", devices=2)
 
 
 @RunIf(min_cuda_gpus=2)
 def test_ema_ddp_spawn(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="gpu", devices=2)
 
 
 @RunIf(skip_windows=True)
 def test_ema_ddp_spawn_cpu(tmp_path):
-    _train(tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2)
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, EMATestCallback(devices=2), strategy="ddp_spawn", accelerator="cpu", devices=2)
 
 
-@pytest.mark.parametrize("crash_on_epoch", [1, 3])
+@pytest.mark.parametrize("crash_on_epoch", [1, 3, 5])
 def test_ema_resume(tmp_path, crash_on_epoch):
-    _train_and_resume(tmp_path, crash_on_epoch=crash_on_epoch)
+    dataset = RandomDataset(32, 32)
+    model1 = TestModel()
+    model2 = deepcopy(model1)
+
+    _train(model1, dataset, tmp_path, EMATestCallback())
+
+    model2.crash_on_epoch = crash_on_epoch
+    model2 = _train_and_resume(model2, dataset, tmp_path)
+
+    for param1, param2 in zip(model1.parameters(), model2.parameters()):
+        assert torch.allclose(param1, param2, atol=0.001)
 
 
 @RunIf(skip_windows=True)
 def test_ema_resume_ddp(tmp_path):
-    _train_and_resume(tmp_path, crash_on_epoch=3, use_ddp=True)
+    model = TestModel()
+    model.crash_on_epoch = 3
+    dataset = RandomDataset(32, 32)
+    _train_and_resume(model, dataset, tmp_path, strategy="ddp_spawn", devices=2)
 
 
 def test_swa(tmp_path):
-    _train(tmp_path, SWATestCallback())
+    model = TestModel()
+    dataset = RandomDataset(32, 32)
+    _train(model, dataset, tmp_path, SWATestCallback())
 
 
 def _train(
+    model: TestModel,
+    dataset: Dataset,
     tmp_path: str,
     callback: WeightAveraging,
-    batch_norm: bool = True,
     strategy: str = "auto",
     accelerator: str = "cpu",
     devices: int = 1,
-    iterable_dataset: bool = False,
     checkpoint_path: Optional[str] = None,
-    crash_on_epoch: Optional[int] = None,
-) -> None:
+    will_crash: bool = False,
+) -> TestModel:
+    deterministic = accelerator == "cpu"
     trainer = Trainer(
-        default_root_dir=tmp_path,
-        enable_progress_bar=False,
-        enable_model_summary=False,
+        accelerator=accelerator,
+        strategy=strategy,
+        devices=devices,
         logger=False,
+        callbacks=callback,
        max_epochs=8,
        num_sanity_val_steps=0,
-        callbacks=callback,
+        enable_checkpointing=will_crash,
+        enable_progress_bar=False,
+        enable_model_summary=False,
        accumulate_grad_batches=2,
-        strategy=strategy,
-        accelerator=accelerator,
-        devices=devices,
-    )
-    model = WeightAveragingTestModel(
-        batch_norm=batch_norm, iterable_dataset=iterable_dataset, crash_on_epoch=crash_on_epoch
+        deterministic=deterministic,
+        default_root_dir=tmp_path,
     )
-
-    if crash_on_epoch is None:
-        trainer.fit(model, ckpt_path=checkpoint_path)
+    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
+    if will_crash:
+        with pytest.raises(Exception, match="CRASH"):
+            trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
     else:
-        with pytest.raises(Exception, match="CRASH TEST"):
-            trainer.fit(model, ckpt_path=checkpoint_path)
-
+        trainer.fit(model, dataloader, ckpt_path=checkpoint_path)
    assert trainer.lightning_module == model
 
 
-def _train_and_resume(tmp_path: str, crash_on_epoch: int, use_ddp: bool = False) -> None:
-    strategy = "ddp_spawn" if use_ddp else "auto"
-    devices = 2 if use_ddp else 1
-
-    _train(
-        tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, crash_on_epoch=crash_on_epoch
-    )
+def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices: int = 1, **kwargs) -> TestModel:
+    _train(model, dataset, tmp_path, EMATestCallback(devices=devices), devices=devices, will_crash=True, **kwargs)
 
     checkpoint_dir = Path(tmp_path) / "checkpoints"
     checkpoint_names = os.listdir(checkpoint_dir)
     assert len(checkpoint_names) == 1
     checkpoint_path = str(checkpoint_dir / checkpoint_names[0])
 
-    _train(
-        tmp_path, EMATestCallback(devices=devices), strategy=strategy, devices=devices, checkpoint_path=checkpoint_path
-    )
+    model = TestModel.load_from_checkpoint(checkpoint_path)
+    callback = EMATestCallback(devices=devices)
+    _train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
+    return model
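
The reworked `test_ema_resume` carries the point of the commit: two models start from identical weights via `deepcopy`, one trains straight through while the other crashes mid-training and resumes from its checkpoint, and their final parameters must agree elementwise. A self-contained sketch of that comparison idiom (the helper name below is illustrative, not from the test file):

```python
from copy import deepcopy

import torch
from torch import nn


def assert_params_close(reference: nn.Module, candidate: nn.Module, atol: float = 1e-3) -> None:
    # Mirrors the test's final loop: every corresponding parameter pair
    # must agree within the tolerance, otherwise resuming changed the result.
    for p1, p2 in zip(reference.parameters(), candidate.parameters()):
        assert torch.allclose(p1, p2, atol=atol)


model1 = nn.Linear(32, 2)
model2 = deepcopy(model1)  # identical starting weights, as in the test
assert_params_close(model1, model2)  # passes: no divergent training yet
```

The comparison is only meaningful because both runs are reproducible: `_train` sets `deterministic=True` on CPU, disables dataloader shuffling, and enables checkpointing only for the run that is expected to crash.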
