Commit 8b8a8cf

Author: Daniel Dale

    replaced test_deep_nested_model with test_complex_nested_model as it was a superset, updated changelog

1 parent: 7fdedf9

File tree (2 files changed: +3, -39 lines)

  CHANGELOG.md
  tests/callbacks/test_finetuning_callback.py

CHANGELOG.md (1 addition, 0 deletions)

@@ -208,6 +208,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 ### Fixed
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
 
 - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accommodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
 
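For context, the `BaseFinetuning` fix referenced above concerns parameters registered directly on a parent (non-leaf) module rather than only on its leaf submodules; such parent-level parameters could previously be missed when freezing or thawing. Below is a minimal sketch of the now-covered case, assuming only the public `BaseFinetuning` static methods that the test file below also exercises; the `ParentWithParam` module is hypothetical, not from this commit.

import torch
import torch.nn as nn

from pytorch_lightning.callbacks import BaseFinetuning


class ParentWithParam(nn.Module):
    """Hypothetical non-leaf module: it has child modules and its own parameter."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)             # leaf child
        self.bn = nn.BatchNorm2d(8)                # leaf child
        self.scale = nn.Parameter(torch.ones(1))   # parameter on the parent itself

    def forward(self, x):
        return self.bn(self.conv(x)) * self.scale


block = ParentWithParam()

# With the fix, freezing should also reach the parent-level parameter,
# while train_bn=True keeps the BatchNorm layer trainable (the same
# behaviour the deleted test asserted for its encoder blocks).
BaseFinetuning.freeze(block, train_bn=True)
assert not block.scale.requires_grad
assert not block.conv.weight.requires_grad
assert block.bn.weight.requires_grad

# Thawing should restore every parameter, including the parent-level one.
BaseFinetuning.make_trainable(block)
assert block.scale.requires_grad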
tests/callbacks/test_finetuning_callback.py (2 additions, 39 deletions)

@@ -307,44 +307,7 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_deep_nested_model():
-
-    class ConvBlock(nn.Module):
-
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
-
-    model = nn.Sequential(
-        OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
-            ("decoder", ConvBlock(128, 10)),
-        ])
-    )
-
-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
-
-    BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
-    assert model.encoder[0].bn.weight.requires_grad
-
-    BaseFinetuning.make_trainable(model)
-    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
-    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
-
-
-def test_parent_module_w_param_model():
+def test_complex_nested_model():
     """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters.
     """

@@ -368,7 +331,7 @@ def __init__(self, in_channels, out_channels):
             super().__init__()
             self.conv = nn.Conv2d(in_channels, out_channels, 3)
             self.act = nn.ReLU()
-            # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
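Per the retained test's docstring, `flatten_modules` is expected to surface parent modules that own parameters directly, so `filter_params` sees parent-level parameters alongside leaf ones. Below is a condensed sketch of that pattern; the module and the exact counts are illustrative assumptions, not copied from the renamed test.

import torch
import torch.nn as nn

from pytorch_lightning.callbacks import BaseFinetuning


class ConvBlockParam(nn.Module):
    """Illustrative block mirroring the test's pattern: a parent module
    holding parent_param directly alongside its leaf children."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.act = nn.ReLU()
        # trivial parameter registered on the parent (non-leaf) module
        self.parent_param = nn.Parameter(torch.zeros(1, dtype=torch.float))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.act(self.conv(x)))


block = ConvBlockParam(3, 16)

# Assumption: with this fix in place, flatten_modules should return the
# three leaves (conv, act, bn) plus the parent module itself, because the
# parent owns a parameter directly: 4 modules in total.
assert len(BaseFinetuning.flatten_modules(block)) == 4

# filter_params should then yield conv.weight, conv.bias, parent_param,
# bn.weight and bn.bias: 5 tensors in total.
assert len(list(BaseFinetuning.filter_params(block, train_bn=True))) == 5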