@@ -307,44 +307,7 @@ def configure_optimizers(self):
307307 trainer .fit (model )
308308
309309
310- def test_deep_nested_model ():
311-
312- class ConvBlock (nn .Module ):
313-
314- def __init__ (self , in_channels , out_channels ):
315- super ().__init__ ()
316- self .conv = nn .Conv2d (in_channels , out_channels , 3 )
317- self .act = nn .ReLU ()
318- self .bn = nn .BatchNorm2d (out_channels )
319-
320- def forward (self , x ):
321- x = self .conv (x )
322- x = self .act (x )
323- return self .bn (x )
324-
325- model = nn .Sequential (
326- OrderedDict ([
327- ("encoder" , nn .Sequential (ConvBlock (3 , 64 ), ConvBlock (64 , 128 ))),
328- ("decoder" , ConvBlock (128 , 10 )),
329- ])
330- )
331-
332- # There's 9 leaf layers in that model
333- assert len (BaseFinetuning .flatten_modules (model )) == 9
334-
335- BaseFinetuning .freeze (model .encoder , train_bn = True )
336- assert not model .encoder [0 ].conv .weight .requires_grad
337- assert model .encoder [0 ].bn .weight .requires_grad
338-
339- BaseFinetuning .make_trainable (model )
340- encoder_params = list (BaseFinetuning .filter_params (model .encoder , train_bn = True ))
341- # The 8 parameters of the encoder are:
342- # conv0.weight, conv0.bias, bn0.weight, bn0.bias
343- # conv1.weight, conv1.bias, bn1.weight, bn1.bias
344- assert len (encoder_params ) == 8
345-
346-
347- def test_parent_module_w_param_model ():
310+ def test_complex_nested_model ():
348311 """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
349312 directly themselves rather than exclusively their submodules containing parameters.
350313 """
@@ -368,7 +331,7 @@ def __init__(self, in_channels, out_channels):
368331 super ().__init__ ()
369332 self .conv = nn .Conv2d (in_channels , out_channels , 3 )
370333 self .act = nn .ReLU ()
371- # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
334+ # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
372335 self .parent_param = nn .Parameter (torch .zeros ((1 ), dtype = torch .float ))
373336 self .bn = nn .BatchNorm2d (out_channels )
374337
0 commit comments