Unable to handle custom weight  #95

@kim-sunghoon

Hello, thanks for the awesome tool!
However, custom weights (`weight1` and `weight2` in the code below) are not reported by torchsummary.

For example, I built the following model, which applies a single 1x1x3x3 convolution kernel to each input channel, so that
input feature map size = C x H x W -->
output feature map size = C x H x W
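The per-channel convolution described above amounts to a depthwise convolution in which every channel shares one 3x3 kernel. A minimal sketch of that equivalence, assuming CPU tensors; the helper names `shared_depthwise_conv` and `looped_conv` are hypothetical and not part of the original code:

```python
import torch
import torch.nn.functional as F

def shared_depthwise_conv(x, weight):
    """Apply one (1, 1, 3, 3) kernel to every channel of x with a single call."""
    channels = x.shape[1]
    # Repeat the shared kernel once per channel: (1, 1, 3, 3) -> (C, 1, 3, 3).
    kernel = weight.repeat(channels, 1, 1, 1)
    # groups=channels convolves each input channel with its own kernel copy.
    return F.conv2d(x, kernel, stride=1, padding=1, groups=channels)

def looped_conv(x, weight):
    """Reference version: convolve channel by channel, then concatenate."""
    outs = [
        F.conv2d(x[:, i:i + 1], weight, stride=1, padding=1)
        for i in range(x.shape[1])
    ]
    return torch.cat(outs, dim=1)

torch.manual_seed(0)
x = torch.randn(1, 5, 10, 10)
w = torch.randn(1, 1, 3, 3)

grouped = shared_depthwise_conv(x, w)
same = torch.allclose(grouped, looped_conv(x, w), atol=1e-5)
out_shape = tuple(grouped.shape)
print(same, out_shape)
```

With `padding=1` and `stride=1`, the spatial size is preserved, so the output keeps the C x H x W shape of the input.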

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torchsummary
  5 import numpy
  6 import sys
  7
  8 class test(nn.Module):
  9     def __init__(self, planes):
 10         super(test, self).__init__()
 11
 12         self.planes = planes
 13         self.weight1 = torch.randn((1,1,3,3), requires_grad = True).to("cuda")
 14         self.weight2 = torch.randn((1,1,3,3), requires_grad = True).to("cuda")
 15         print(self.weight1)
 16         print(self.weight2)
 17         self.bn1 = nn.BatchNorm2d(planes)
 18         self.bn2 = nn.BatchNorm2d(planes)
 19
 20     def forward(self, x):
 21         out = None
 22         for i in range (self.planes):
 23             if i == 0:
 24                 out = F.conv2d(x[:,i,:,:].unsqueeze(1), self.weight1, stride=1, padding=1)
 25             else:
 26                 temp_out = F.conv2d(x[:,i,:,:].unsqueeze(1), self.weight1, stride=1, padding=1)
 27                 out = torch.cat(([out, temp_out]), dim=1)
 28         out = self.bn1(out)
 29         for i in range (self.planes):
 30             if i == 0:
 31                 out = F.conv2d(x[:,i,:,:].unsqueeze(1), self.weight2, stride=1, padding=1)
 32             else:
 33                 temp_out = F.conv2d(x[:,i,:,:].unsqueeze(1), self.weight2, stride=1, padding=1)
 34                 out = torch.cat(([out, temp_out]), dim=1)
 35         out = self.bn2(out)
 36
 37         return out
 38
 39
 40
 41
 42
 43 if __name__ == "__main__":
 44     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 45
 46     input1 = torch.ones(1,5,10,10).to(device)
 47     print(input1)
 48     #  print(input1)
 49     planes = input1.shape[1]
 50     #  print(input1[:,0,:,:])
 51     model = test(planes)
 52     model = model.to(device)
 53     torchsummary.summary(model, (5,10,10))
 54     output = model(input1)
 55     print("Total Output ")
 56     print(output)
 57     print(output.size())

Line 53 reports only the BatchNorm parameters:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
       BatchNorm2d-1            [-1, 5, 10, 10]              10
       BatchNorm2d-2            [-1, 5, 10, 10]              10
================================================================
Total params: 20
Trainable params: 20
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.00
Estimated Total Size (MB): 0.01
----------------------------------------------------------------

Is there any way to monitor customized weights?

Thank you!
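One possible direction, offered as a sketch and not as a confirmed fix: plain tensors stored as attributes are not registered with the module, so they never appear in `model.parameters()`. Wrapping them in `nn.Parameter` registers them, which makes a manual count straightforward. The class name `TestWithParams` is hypothetical, the `forward` method is omitted because only registration is being shown, and this sketch does not claim that torchsummary's table will list the weights.

```python
import torch
import torch.nn as nn

class TestWithParams(nn.Module):  # hypothetical variant of the `test` class
    def __init__(self, planes):
        super().__init__()
        self.planes = planes
        # nn.Parameter registers each tensor with the module, so it is
        # returned by parameters() and moved by model.to(device).
        self.weight1 = nn.Parameter(torch.randn(1, 1, 3, 3))
        self.weight2 = nn.Parameter(torch.randn(1, 1, 3, 3))
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)

model = TestWithParams(5)

# List every registered parameter with its shape and element count.
names = []
for name, param in model.named_parameters():
    names.append(name)
    print(name, tuple(param.shape), param.numel())

# Manual totals, independent of any summary tool.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total params:", total)
print("Trainable params:", trainable)
```

Here the two 3x3 kernels contribute 9 elements each and each `BatchNorm2d(5)` contributes a weight and a bias of 5 elements, so the manual total is 38.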
