
Commit 7fdedf9

Author: Daniel Dale (committed)
bugfix for #7930 w/ associated new test
1 parent 96433d0 commit 7fdedf9

2 files changed: 66 additions, 7 deletions


pytorch_lightning/callbacks/finetuning.py

Lines changed: 10 additions & 7 deletions
@@ -105,7 +105,8 @@ def on_load_checkpoint(
     @staticmethod
     def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         """
-        This function is used to flatten a module or an iterable of modules into a list of its modules.
+        This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+        with no children) and parent modules that have parameters directly themselves.
 
         Args:
             modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        # Leaf nodes in the graph have no children, so we use that to filter
-        return [m for m in _modules if not list(m.children())]
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]
 
     @staticmethod
     def filter_params(
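As an aside, here is a minimal sketch (not part of this commit; the `ParentWithParam` class below is purely illustrative) of the case the revised filter captures: a non-leaf module that owns a parameter directly would be dropped entirely by the previous leaf-only filter.

import torch
from torch import nn


class ParentWithParam(nn.Module):
    """A non-leaf module that holds a parameter directly, alongside a child module."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # parameter owned by the parent itself
        self.lin = nn.Linear(4, 4)                # child (leaf) module

    def forward(self, x):
        return self.scale * self.lin(x)


model = ParentWithParam()

# Previous filter: leaf modules only -- the parent (and therefore `scale`) is dropped.
assert [m for m in model.modules() if not list(m.children())] == [model.lin]

# Revised filter: leaf modules plus parents that hold parameters directly.
assert [m for m in model.modules() if not list(m.children()) or m._parameters] == [model, model.lin]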
@@ -136,15 +137,15 @@ def filter_params(
             modules: A given module or an iterable of modules
             train_bn: Whether to train BatchNorm module
             requires_grad: Whether to create a generator for trainable or non-trainable parameters.
-
         Returns:
             Generator
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for mod in modules:
             if isinstance(mod, _BatchNorm) and not train_bn:
                 continue
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 if param.requires_grad == requires_grad:
                     yield param

@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
-            for param in module.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
     @staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                for param in mod.parameters():
+                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+                for param in mod.parameters(recurse=False):
                     param.requires_grad = False
 
     @staticmethod
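The `recurse=False` changes above guard against double counting once parent modules are kept in the flattened list: iterating `parameters()` recursively on such a parent re-yields its children's parameters, which the flattened leaf modules already contribute. A short continuation of the illustrative sketch above (assuming `model = ParentWithParam()` from that sketch):

# Flattened list produced by the revised filter: [ParentWithParam, nn.Linear]
flattened = [m for m in model.modules() if not list(m.children()) or m._parameters]

# Recursive iteration re-yields lin.weight/lin.bias through the parent: scale + 2 + 2 params.
assert len([p for m in flattened for p in m.parameters()]) == 5

# Non-recursive iteration yields each parameter exactly once: scale, lin.weight, lin.bias.
assert len([p for m in flattened for p in m.parameters(recurse=False)]) == 3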

tests/callbacks/test_finetuning_callback.py

Lines changed: 56 additions & 0 deletions
@@ -342,3 +342,59 @@ def forward(self, x):
     # conv0.weight, conv0.bias, bn0.weight, bn0.bias
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
     assert len(encoder_params) == 8
+
+
+def test_parent_module_w_param_model():
+    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters.
+    """
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
+    assert len(BaseFinetuning.flatten_modules(model)) == 10
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 9
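For reference, a quick breakdown of the counts asserted in the new test (a sketch, assuming the `model` built in the test above): each of the three blocks contributes its conv/act/bn leaf modules, and the ConvBlockParam instance additionally appears itself because it holds `parent_param` directly.

flattened = BaseFinetuning.flatten_modules(model)
assert len([m for m in flattened if not list(m.children())]) == 9  # 3 blocks x (conv, act, bn)
assert len([m for m in flattened if list(m.children())]) == 1      # the encoder's ConvBlockParam
# Hence 10 flattened modules in total, and 9 encoder parameters: 4 from each conv/bn pair
# in the two encoder blocks, plus parent_param.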
