
Commit f944e62

Merge branch 'master' into update/tpu
2 parents: 851a8d9 + 6dc1078

6 files changed: +48 -16 lines changed


CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -170,6 +170,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Sanitize `None` params during pruning ([#6836](https://github.com/PyTorchLightning/pytorch-lightning/pull/6836))
+
+
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
 
@@ -200,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
 
 
+- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
+
+
 - Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))

pytorch_lightning/callbacks/finetuning.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
                 # When `current_epoch` is 10, feature_extractor will start training.
                 if current_epoch == self._unfreeze_at_epoch:
                     self.unfreeze_and_add_param_group(
-                        module=pl_module.feature_extractor,
+                        modules=pl_module.feature_extractor,
                         optimizer=optimizer,
                         train_bn=True,
                     )
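The only change here is the keyword name in the docstring example: `BaseFinetuning.unfreeze_and_add_param_group` takes `modules` (a single module or a list of modules), not `module`. For context, a minimal sketch of a fine-tuning callback built around that hook; the `feature_extractor` attribute and the unfreeze epoch are illustrative assumptions, not part of this diff:

from pytorch_lightning.callbacks import BaseFinetuning

class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    # Sketch based on the surrounding docstring example.

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # Assumes the LightningModule exposes a `feature_extractor` submodule.
        self.freeze(pl_module.feature_extractor)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.feature_extractor,  # keyword is `modules`, not `module`
                optimizer=optimizer,
                train_bn=True,
            )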

pytorch_lightning/callbacks/pruning.py

Lines changed: 3 additions & 1 deletion
@@ -422,7 +422,9 @@ def sanitize_parameters_to_prune(
         current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
 
         if parameters_to_prune is None:
-            parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)]
+            parameters_to_prune = [
+                (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
+            ]
         elif (
             isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0
             and all(len(p) == 2 for p in parameters_to_prune)
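The previous `hasattr(m, p)` check passes even when a registered parameter exists but is set to `None`, as happens for layers created with `bias=False`, and pruning such a pair would then fail downstream. A small, self-contained illustration of the difference (not part of the diff):

import torch.nn as nn

layer = nn.Linear(32, 32, bias=False)

print(hasattr(layer, "bias"))                 # True: the attribute exists, but it holds None
print(getattr(layer, "bias", None) is None)   # True: so the (layer, "bias") pair is now skipped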

pytorch_lightning/callbacks/stochastic_weight_avg.py

Lines changed: 4 additions & 3 deletions
@@ -187,14 +187,15 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
                 anneal_strategy=self._annealing_strategy,
                 last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
             )
+            _scheduler_config = _get_default_scheduler_config()
+            assert _scheduler_config["interval"] == "epoch" and _scheduler_config["frequency"] == 1
+            _scheduler_config["scheduler"] = self._swa_scheduler
 
             if trainer.lr_schedulers:
                 lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
                 rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
-                trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
+                trainer.lr_schedulers[0] = _scheduler_config
             else:
-                _scheduler_config = _get_default_scheduler_config()
-                _scheduler_config["scheduler"] = self._swa_scheduler
                 trainer.lr_schedulers.append(_scheduler_config)
 
             self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
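Rather than only swapping the scheduler object, which silently kept whatever `interval` and `frequency` the user's scheduler config carried (for example a per-step interval), the callback now installs a fresh default config so `SWALR` is always stepped once per epoch. A rough, self-contained sketch of the shape of the entry that ends up in `trainer.lr_schedulers[0]`; the real default config may carry additional keys:

import torch
from torch.optim.swa_utils import SWALR

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)

# Assumed minimal shape of the scheduler config installed by the callback:
scheduler_config = {
    "scheduler": swa_scheduler,
    "interval": "epoch",  # enforced regardless of the user's original scheduler interval
    "frequency": 1,
}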

tests/callbacks/test_pruning.py

Lines changed: 11 additions & 8 deletions
@@ -36,7 +36,7 @@ def __init__(self):
         self.layer = Sequential(
             OrderedDict([
                 ("mlp_1", nn.Linear(32, 32)),
-                ("mlp_2", nn.Linear(32, 32)),
+                ("mlp_2", nn.Linear(32, 32, bias=False)),
                 ("mlp_3", nn.Linear(32, 2)),
             ])
         )
@@ -85,7 +85,10 @@ def train_with_pruning_callback(
     if parameters_to_prune:
         pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
     else:
-        pruning_kwargs["parameter_names"] = ["weight"]
+        if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
+            pruning_kwargs["parameter_names"] = ["weight"]
+        else:
+            pruning_kwargs["parameter_names"] = ["weight", "bias"]
     if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
         pruning_kwargs["pruning_dim"] = 0
     if pruning_fn == "ln_structured":
@@ -249,14 +252,14 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
     actual = [m for m in actual if m.startswith("Applied")]
     assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
+        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 500 (48.83%)",  # noqa: E501
+        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 44 (68.75%)",  # noqa: E501
         "Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)",
-        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 506 (49.41%) -> 633 (61.82%)",  # noqa: E501
-        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 38 (59.38%) -> 47 (73.44%)",  # noqa: E501
+        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 500 (48.83%) -> 635 (62.01%)",  # noqa: E501
+        "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 44 (68.75%) -> 45 (70.31%)",  # noqa: E501
         "Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)",
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
-        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
+        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 635 (62.01%) -> 830 (81.05%)",  # noqa: E501
+        "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 45 (70.31%) -> 54 (84.38%)",  # noqa: E501
     ]
 
     filepath = str(tmpdir / "foo.ckpt")
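Building `mlp_2` without a bias lets the unstructured-pruning path also request `parameter_names=["weight", "bias"]` and exercise the sanitization fix above: the missing `(mlp_2, "bias")` pair is dropped instead of crashing, and the hard-coded expected counts are updated to match. A hedged sketch of the equivalent user-facing setup (names and amounts are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

# Prune both weights and biases; layers declared with bias=False are skipped automatically.
pruning = ModelPruning(pruning_fn="l1_unstructured", parameter_names=["weight", "bias"], amount=0.5)
trainer = Trainer(callbacks=[pruning])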

tests/callbacks/test_stochastic_weight_avg.py

Lines changed: 23 additions & 3 deletions
@@ -27,16 +27,18 @@
 
 if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
+    from torch.optim.swa_utils import SWALR
 
     class SwaTestModel(BoringModel):
 
-        def __init__(self, batchnorm: bool = True):
+        def __init__(self, batchnorm: bool = True, interval: str = "epoch"):
             super().__init__()
             layers = [nn.Linear(32, 32)]
             if batchnorm:
                 layers.append(nn.BatchNorm1d(32))
             layers += [nn.ReLU(), nn.Linear(32, 2)]
             self.layer = nn.Sequential(*layers)
+            self.interval = interval
 
         def training_step(self, batch, batch_idx):
             output = self.forward(batch)
@@ -46,6 +48,14 @@ def training_step(self, batch, batch_idx):
         def train_dataloader(self):
             return DataLoader(RandomDataset(32, 64), batch_size=2)
 
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            return {
+                "optimizer": optimizer,
+                "scheduler": torch.optim.lr_scheduler.StepLR(optimizer, step_size=1),
+                "interval": self.interval,
+            }
+
     class SwaTestCallback(StochasticWeightAveraging):
         update_parameters_calls: int = 0
         transfer_weights_calls: int = 0
@@ -61,6 +71,10 @@ def transfer_weights(self, *args, **kwargs):
         def on_train_epoch_start(self, trainer, *args):
             super().on_train_epoch_start(trainer, *args)
             assert trainer.train_loop._skip_backward == (trainer.current_epoch > self.swa_end)
+            if self.swa_start <= trainer.current_epoch:
+                assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
+                assert trainer.lr_schedulers[0]["interval"] == "epoch"
+                assert trainer.lr_schedulers[0]["frequency"] == 1
 
         def on_train_epoch_end(self, trainer, *args):
             super().on_train_epoch_end(trainer, *args)
@@ -89,8 +103,8 @@ def on_train_end(self, trainer, pl_module):
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1):
-    model = SwaTestModel(batchnorm=batchnorm)
+def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_processes=1, interval="epoch"):
+    model = SwaTestModel(batchnorm=batchnorm, interval=interval)
     swa_start = 2
     max_epochs = 5
     swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
@@ -140,6 +154,12 @@ def test_swa_callback(tmpdir, batchnorm: bool):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
+@RunIf(min_torch="1.6.0")
+@pytest.mark.parametrize("interval", ("epoch", "step"))
+def test_swa_callback_scheduler_step(tmpdir, interval: bool):
+    train_with_swa(tmpdir, interval=interval)
+
+
 @RunIf(min_torch="1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
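The new `test_swa_callback_scheduler_step` runs the same training with the model's scheduler configured at either an "epoch" or a "step" interval, and the callback assertions above confirm that `SWALR` ends up installed with an epoch interval and frequency 1 in both cases. A hedged sketch of the user-facing situation the test covers (module, dimensions and SWA settings are placeholders):

import torch
from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

class StepIntervalModel(pl.LightningModule):
    # Placeholder model; the interesting part is the per-step scheduler below.
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # A "step"-interval scheduler: previously SWA kept this interval when swapping in SWALR.
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}

trainer = pl.Trainer(max_epochs=5, callbacks=[StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.1)])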
