
Commit 9e5d84d

mibaumgartner authored and carmocca committed
Enforce an epoch scheduler interval when using SWA (#6588)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent bb4fd7e commit 9e5d84d

2 files changed: +29 additions, −7 deletions

pytorch_lightning/callbacks/swa.py

4 additions, 3 deletions

@@ -189,14 +189,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
             )
+            _scheduler_config = _get_default_scheduler_config()
+            assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
+            _scheduler_config["scheduler"] = self._swa_scheduler
 
             if trainer.lr_schedulers:
                 lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                 rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-                trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+                trainer.lr_schedulers[0] = _scheduler_config
             else:
-                _scheduler_config = _get_default_scheduler_config()
-                _scheduler_config["scheduler"] = self._swa_scheduler
                 trainer.lr_schedulers.append(_scheduler_config)
 
             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
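Before this commit, the callback only swapped the "scheduler" entry inside the user's existing scheduler config, so a scheduler configured with a "step" interval kept that interval and SWALR ended up being stepped once per batch. The change builds a full default config (asserted to use interval="epoch", frequency=1) and replaces the entry wholesale. Below is a minimal standalone sketch of the config shape being enforced, not Lightning's actual code; _default_scheduler_config is a hypothetical stand-in for _get_default_scheduler_config().

# Hypothetical stand-in for Lightning's _get_default_scheduler_config(); the
# commit's assert pins down that it yields interval="epoch", frequency=1.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
from torch.optim.swa_utils import SWALR  # requires torch >= 1.6


def _default_scheduler_config(scheduler):
    return {"scheduler": scheduler, "interval": "epoch", "frequency": 1}


model = nn.Linear(32, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# A user scheduler configured to step every batch.
lr_schedulers = [{"scheduler": StepLR(optimizer, step_size=1), "interval": "step", "frequency": 1}]

swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# Old behaviour: lr_schedulers[0]["scheduler"] = swa_scheduler, which kept the
# user's "step" interval. New behaviour: the whole entry is replaced, so SWALR
# is stepped once per epoch.
lr_schedulers[0] = _default_scheduler_config(swa_scheduler)
assert lr_schedulers[0]["interval"] == "epoch" and lr_schedulers[0]["frequency"] == 1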

tests/callbacks/test_swa.py

25 additions, 4 deletions

@@ -24,19 +24,22 @@
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
+    from torch.optim.swa_utils import SWALR
 
 class SwaTestModel(BoringModel):
 
-    def __init__(self, batchnorm: bool = True):
+    def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
         super().__init__()
         layers = [nn.Linear(32, 32)]
         if batchnorm:
             layers.append(nn.BatchNorm1d(32))
         layers += [nn.ReLU(), nn.Linear(32, 2)]
         self.layer = nn.Sequential(*layers)
+        self.interval = interval
 
     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
@@ -46,6 +49,14 @@ def training_step(self, batch, batch_idx):
     def train_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=2)
 
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return {
+            "optimizer": optimizer,
+            "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+            "interval": self.interval,
+        }
+
 class SwaTestCallback(StochasticWeightAveraging):
     update_parameters_calls: int = 0
     transfer_weights_calls: int = 0
@@ -61,6 +72,10 @@ def transfer_weights(self, *args, **kwargs):
     def on_train_epoch_start(self, trainer, *args):
         super().on_train_epoch_start(trainer, *args)
         assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+        if self.swa_start <= trainer.current_epoch:
+            assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+            assert trainer.lr_schedulers[0]["interval"] == "epoch"
+            assert trainer.lr_schedulers[0]["frequency"] == 1
 
     def on_train_epoch_end(self, trainer, *args):
         super().on_train_epoch_end(trainer, *args)
@@ -89,8 +104,8 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
-    model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -147,7 +162,13 @@ def test_swa_callback(tmpdir, batchnorm):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: bool):
+    train_with_swa(tmpdir, interval=interval)
+
+
+@RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
         StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
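For context, a hedged user-level sketch of the scenario the new parametrized test covers: a LightningModule whose scheduler uses a "step" interval, trained with the StochasticWeightAveraging callback. The module name, learning rates, dataset, and Trainer arguments below are illustrative and not taken from this repository, and the scheduler is returned in the documented nested "lr_scheduler" dict form rather than the flat dict used by the test.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging


class StepIntervalModel(LightningModule):
    # Hypothetical module mirroring the test setup: per-batch scheduler stepping.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        # StepLR is stepped every batch until the SWA phase starts; after this
        # commit, the callback then swaps in SWALR with an enforced "epoch" interval.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
                "interval": "step",
            },
        }


if __name__ == "__main__":
    trainer = Trainer(
        max_epochs=5,
        callbacks=[StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.05)],
    )
    trainer.fit(StepIntervalModel())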
