
Commit e477e51

Author: Daniel Dale
Parent: 4c0baa3

Replaced test_deep_nested_model with test_complex_nested_model, as the latter is a superset; updated the changelog.
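The "superset" rationale is that the retained test also covers parent (non-leaf) modules that hold parameters directly, which the removed test could not exercise. A minimal sketch of that case (the ParentWithParam module and its assertions below are illustrative, not taken from the commit; BaseFinetuning is the pytorch-lightning callback touched by #7931):

import torch
import torch.nn as nn

from pytorch_lightning.callbacks import BaseFinetuning


class ParentWithParam(nn.Module):

    def __init__(self):
        super().__init__()
        # parameter owned by the parent (non-leaf) module itself
        self.parent_param = nn.Parameter(torch.zeros(1))
        # nested leaf submodule with its own parameters
        self.child = nn.Linear(4, 4)


model = ParentWithParam()

# freeze() should reach the parent-held parameter as well as the leaves
BaseFinetuning.freeze(model, train_bn=True)
assert not model.parent_param.requires_grad
assert not model.child.weight.requires_grad

# make_trainable() thaws everything back
BaseFinetuning.make_trainable(model)
assert model.parent_param.requires_grad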

File tree: 2 files changed (+3, -51)

  CHANGELOG.md
  tests/callbacks/test_finetuning_callback.py

CHANGELOG.md

Lines changed: 1 addition & 12 deletions
@@ -35,9 +35,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
 
 
-- Added support for passing any class to `is_overridden` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
-
-
 - Added `sub_dir` parameter to `TensorBoardLogger` ([#6195](https://github.com/PyTorchLightning/pytorch-lightning/pull/6195))
 
 
@@ -71,9 +68,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
 
 
-- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
-
-
 - Added a warning if `Trainer(log_every_n_steps)` is a value too high for the training dataloader ([#7734](https://github.com/PyTorchLightning/pytorch-lightning/pull/7734))
 
 
@@ -178,12 +172,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`. ([#7891](https://github.com/PyTorchLightning/pytorch-lightning/pull/7891))
 
 
-- Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918))
-
-
-- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))
-
-
 ### Removed
 
 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
@@ -208,6 +196,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 ### Fixed
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
 
 - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
 
tests/callbacks/test_finetuning_callback.py

Lines changed: 2 additions & 39 deletions
@@ -307,44 +307,7 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_deep_nested_model():
-
-    class ConvBlock(nn.Module):
-
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
-
-    model = nn.Sequential(
-        OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
-            ("decoder", ConvBlock(128, 10)),
-        ])
-    )
-
-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
-
-    BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
-    assert model.encoder[0].bn.weight.requires_grad
-
-    BaseFinetuning.make_trainable(model)
-    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
-    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
-
-
-def test_parent_module_w_param_model():
+def test_complex_nested_model():
     """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters.
     """
@@ -368,7 +331,7 @@ def __init__(self, in_channels, out_channels):
             super().__init__()
             self.conv = nn.Conv2d(in_channels, out_channels, 3)
             self.act = nn.ReLU()
-            # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
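For reference, a hedged usage sketch of the parameter counting the retained test relies on (ConvBlockParam mirrors the block in the diff above; the expected count assumes post-fix filter_params behavior): each block contributes conv.weight, conv.bias, bn.weight, and bn.bias, plus the parent-held parent_param.

import torch
import torch.nn as nn

from pytorch_lightning.callbacks import BaseFinetuning


class ConvBlockParam(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.act = nn.ReLU()
        # trivial parameter owned by the (non-leaf) block itself
        self.parent_param = nn.Parameter(torch.zeros(1, dtype=torch.float))
        self.bn = nn.BatchNorm2d(out_channels)


block = ConvBlockParam(3, 64)
# conv.weight, conv.bias, bn.weight, bn.bias, plus parent_param -> 5 parameters
assert len(list(BaseFinetuning.filter_params(block, train_bn=True))) == 5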
