Skip to content

Commit 81df232

Browse files
authored
Add size information
Added information for estimating the total size of the model. Estimates are taken from http://jacobkimmel.github.io/pytorch_estimating_model_size/. The change: calculates the size of the input, the parameters, and the forward/backward-pass intermediate variables; prints out these estimates and their total; makes batch_size an optional input.
1 parent 6d9f77c commit 81df232

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

torchsummary/torchsummary.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33
from torch.autograd import Variable
44

55
from collections import OrderedDict
6+
import numpy as np
67

7-
8-
def summary(model, input_size, device="cuda"):
8+
def summary(model, input_size, batch_size=-1,device="cuda"):
99
def register_hook(module):
1010
def hook(module, input, output):
1111
class_name = str(module.__class__).split('.')[-1].split("'")[0]
@@ -14,12 +14,12 @@ def hook(module, input, output):
1414
m_key = '%s-%i' % (class_name, module_idx+1)
1515
summary[m_key] = OrderedDict()
1616
summary[m_key]['input_shape'] = list(input[0].size())
17-
summary[m_key]['input_shape'][0] = -1
17+
summary[m_key]['input_shape'][0] = batch_size
1818
if isinstance(output, (list,tuple)):
1919
summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output]
2020
else:
2121
summary[m_key]['output_shape'] = list(output.size())
22-
summary[m_key]['output_shape'][0] = -1
22+
summary[m_key]['output_shape'][0] = batch_size
2323

2424
params = 0
2525
if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
@@ -67,18 +67,31 @@ def hook(module, input, output):
6767
print(line_new)
6868
print('================================================================')
6969
total_params = 0
70+
total_output = 0
7071
trainable_params = 0
7172
for layer in summary:
7273
# input_shape, output_shape, trainable, nb_params
7374
line_new = '{:>20} {:>25} {:>15}'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params']))
7475
total_params += summary[layer]['nb_params']
76+
total_output += np.prod(summary[layer]['output_shape'])
7577
if 'trainable' in summary[layer]:
7678
if summary[layer]['trainable'] == True:
7779
trainable_params += summary[layer]['nb_params']
7880
print(line_new)
81+
#assume 4 bytes/number (float on cuda).
82+
total_input_size = abs(np.prod(input_size)*batch_size*4./(1024**2.))
83+
total_output_size = abs(2.*total_output*4./(1024**2.)) #x2 for gradients
84+
total_params_size = abs(total_params.numpy()*4./(1024**2.))
85+
total_size = total_params_size + total_output_size + total_input_size
86+
7987
print('================================================================')
8088
print('Total params: {0:,}'.format(total_params))
8189
print('Trainable params: {0:,}'.format(trainable_params))
8290
print('Non-trainable params: {0:,}'.format(total_params - trainable_params))
8391
print('----------------------------------------------------------------')
84-
# return summary
92+
print('Input size (MB): %0.2f' % total_input_size)
93+
print('Forward/backward pass size (MB): %0.2f' % total_output_size)
94+
print('Params size (MB): %0.2f' % total_params_size)
95+
print('Estimated Total Size (MB): %0.2f' % total_size)
96+
print('----------------------------------------------------------------')
97+
#return summary

0 commit comments

Comments (0)