@@ -3,9 +3,9 @@
 from torch.autograd import Variable
 
 from collections import OrderedDict
-
+import numpy as np
 
-def summary(model, input_size, device="cuda"):
+def summary(model, input_size, batch_size=-1, device="cuda"):
     def register_hook(module):
         def hook(module, input, output):
             class_name = str(module.__class__).split('.')[-1].split("'")[0]
@@ -14,12 +14,12 @@ def hook(module, input, output):
             m_key = '%s-%i' % (class_name, module_idx + 1)
             summary[m_key] = OrderedDict()
             summary[m_key]['input_shape'] = list(input[0].size())
-            summary[m_key]['input_shape'][0] = -1
+            summary[m_key]['input_shape'][0] = batch_size
             if isinstance(output, (list, tuple)):
                 summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output]
             else:
                 summary[m_key]['output_shape'] = list(output.size())
-                summary[m_key]['output_shape'][0] = -1
+                summary[m_key]['output_shape'][0] = batch_size
 
             params = 0
             if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
@@ -67,18 +67,31 @@ def hook(module, input, output):
     print(line_new)
     print('================================================================')
     total_params = 0
+    total_output = 0
     trainable_params = 0
     for layer in summary:
         # input_shape, output_shape, trainable, nb_params
         line_new = '{:>20} {:>25} {:>15}'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params']))
         total_params += summary[layer]['nb_params']
+        total_output += np.prod(summary[layer]['output_shape'])
         if 'trainable' in summary[layer]:
             if summary[layer]['trainable'] == True:
                 trainable_params += summary[layer]['nb_params']
         print(line_new)
+    # assume 4 bytes/number (float on cuda).
+    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
+    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
+    total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
+    total_size = total_params_size + total_output_size + total_input_size
+
     print('================================================================')
     print('Total params: {0:,}'.format(total_params))
     print('Trainable params: {0:,}'.format(trainable_params))
     print('Non-trainable params: {0:,}'.format(total_params - trainable_params))
     print('----------------------------------------------------------------')
-    # return summary
+    print('Input size (MB): %0.2f' % total_input_size)
+    print('Forward/backward pass size (MB): %0.2f' % total_output_size)
+    print('Params size (MB): %0.2f' % total_params_size)
+    print('Estimated Total Size (MB): %0.2f' % total_size)
+    print('----------------------------------------------------------------')
+    # return summary
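The size accounting added above assumes 4 bytes per value (float32) and converts to MB as bytes / 1024**2; the forward/backward figure is doubled because every stored activation gets a same-sized gradient during the backward pass. A standalone sketch of that arithmetic follows, with entirely made-up shapes and parameter counts for illustration:

import numpy as np

# Reproduces the size accounting from the hunk above on fabricated
# numbers: a batch of 8 RGB 224x224 inputs and two invented layer outputs.
batch_size = 8
input_size = (3, 224, 224)                              # one sample, CHW
output_shapes = [[8, 64, 224, 224], [8, 64, 112, 112]]  # hypothetical
total_params = 1792                                     # hypothetical

bytes_per_num = 4.                                      # float32
mb = 1024. ** 2

total_input_size = np.prod(input_size) * batch_size * bytes_per_num / mb
total_output = sum(np.prod(s) for s in output_shapes)
# x2: each stored activation has a same-sized gradient in backward.
total_output_size = 2. * total_output * bytes_per_num / mb
total_params_size = total_params * bytes_per_num / mb

print('Input size (MB): %0.2f' % total_input_size)                    # 4.59
print('Forward/backward pass size (MB): %0.2f' % total_output_size)   # 245.00
print('Params size (MB): %0.2f' % total_params_size)                  # 0.01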
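And a minimal usage sketch of the new batch_size argument. The model, input shape, and the torchsummary import path are illustrative assumptions, not part of this commit; the default batch_size=-1 keeps the old behavior of printing -1 as the batch dimension.

import torch
from torchvision import models
from torchsummary import summary

# Any model works; torchvision's VGG-16 is just a stand-in here.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = models.vgg16().to(device)

# input_size is one sample's shape (C, H, W); batch_size feeds the
# printed layer shapes and the input/activation memory estimates.
summary(model, (3, 224, 224), batch_size=8, device=device)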