@@ -342,3 +342,60 @@ def forward(self, x):
     # conv0.weight, conv0.bias, bn0.weight, bn0.bias
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
     assert len(encoder_params) == 8
+
+
+def test_parent_module_w_param_model():
+    """Test flattening, freezing, and thawing of models that contain parent (non-leaf) modules holding
+    parameters directly themselves, rather than all parameters belonging to leaf submodules.
+    """
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add a trivial test parameter to the conv block to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros(1, dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
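+    # (the conv/act/bn leaves of each of the 3 blocks, plus the ConvBlockParam parent itself)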
+    assert len(BaseFinetuning.flatten_modules(model)) == 10
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad  # validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # validate the parent module parameter is frozen
+    assert model.encoder[0].bn.weight.requires_grad  # with train_bn=True, BatchNorm parameters stay trainable
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 9
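
Note: the flatten count of 10 asserted above follows from a selection rule that keeps a module when it is either a leaf (no children) or a parent that registers parameters directly on itself. A minimal sketch of such a rule, using only standard torch.nn APIs (an illustration of the behavior under test, not necessarily BaseFinetuning.flatten_modules verbatim):

import torch.nn as nn

def flatten_modules_sketch(module: nn.Module):
    # Sketch only: keep leaves (no children) and parents that hold parameters directly.
    # recurse=False restricts the check to the module's own parameters, excluding submodules.
    return [
        m for m in module.modules()
        if not list(m.children()) or list(m.parameters(recurse=False))
    ]

Applied to the test model above, this returns the ConvBlockParam instance plus the nine conv/act/bn leaves of the three blocks, i.e. exactly the 10 modules the assertion expects.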