46 changes: 46 additions & 0 deletions torchsummary/tests/test_models/test_model.py
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class SingleInputNet(nn.Module):
def __init__(self):
@@ -19,6 +21,7 @@ def forward(self, x):
x = self.fc2(x)
return F.log_softmax(x, dim=1)


class MultipleInputNet(nn.Module):
def __init__(self):
super(MultipleInputNet, self).__init__()
@@ -36,6 +39,7 @@ def forward(self, x1, x2):
x = torch.cat((x1, x2), 0)
return F.log_softmax(x, dim=1)


class MultipleInputNetDifferentDtypes(nn.Module):
def __init__(self):
super(MultipleInputNetDifferentDtypes, self).__init__()
@@ -54,3 +58,45 @@ def forward(self, x1, x2):
# set x2 to FloatTensor
x = torch.cat((x1, x2), 0)
return F.log_softmax(x, dim=1)


class NestedNet(nn.Module):
def __init__(self):
super(NestedNet, self).__init__()
self.conv_block1 = ConvBlock(1, 10, 5)
self.conv_block2 = ConvBlock(10, 20, 5)
self.conv_drop = nn.Dropout2d(0.3)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
x = F.relu(self.conv_block1(x))
x = F.relu((self.conv_drop(self.conv_block2(x))))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)


class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size)
self.bn = nn.BatchNorm2d(out_channels)
self.pool = nn.MaxPool2d(2, stride=2)

def forward(self, x):
x = F.relu(self.conv(x))
x = self.bn(x)
x = self.pool(x)
return x


class CustomModule(nn.Module):
def __init__(self):
super(CustomModule, self).__init__()
weight_tensor = torch.rand(50, 50)
self.W = Parameter(weight_tensor, requires_grad=True)

def forward(self, x):
return torch.einsum("bij,jk->bik", x, self.W)
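The new test models pin down the two behaviors this PR targets: nested submodules (NestedNet wrapping two ConvBlocks) and a bare nn.Parameter with a non-standard name (CustomModule's einsum weight), which the old weight/bias attribute checks would not generalize to. A quick sanity check of the expected totals, using only plain PyTorch; the import path assumes the repo layout shown above:

```python
import torch
from torchsummary.tests.test_models.test_model import NestedNet, CustomModule

# Expected totals, worked out by hand (they match the 21900/2500 asserted
# in the tests below):
#   ConvBlock(1, 10, 5):  conv 10*1*5*5 + 10 = 260,  bn 10 + 10 = 20  ->   280
#   ConvBlock(10, 20, 5): conv 20*10*5*5 + 20 = 5020, bn 20 + 20 = 40 ->  5060
#   fc1: 320*50 + 50 = 16050;  fc2: 50*10 + 10 = 510
#   280 + 5060 + 16050 + 510 = 21900
print(sum(p.numel() for p in NestedNet().parameters()))     # 21900

# CustomModule's only parameter is the bare 50x50 einsum weight.
print(sum(p.numel() for p in CustomModule().parameters()))  # 2500
```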
35 changes: 32 additions & 3 deletions torchsummary/tests/unit_tests/torchsummary_test.py
@@ -1,11 +1,14 @@
import unittest
from torchsummary import summary, summary_string
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, MultipleInputNetDifferentDtypes
from torchsummary.torchsummary import _build_summary_dict, _build_summary_string
from torchsummary.tests.test_models.test_model import SingleInputNet, MultipleInputNet, \
MultipleInputNetDifferentDtypes, NestedNet, CustomModule
import torch

gpu_if_available = "cuda:0" if torch.cuda.is_available() else "cpu"

class torchsummaryTests(unittest.TestCase):

class TorchSummaryTests(unittest.TestCase):
def test_single_input(self):
model = SingleInputNet()
input = (1, 28, 28)
@@ -48,8 +51,34 @@ def test_multiple_input_types(self):
self.assertEqual(total_params, 31120)
self.assertEqual(trainable_params, 31120)

def test_recursive(self):
model = NestedNet()
input = (1, 28, 28)
summary = _build_summary_dict(model, [input], device='cpu')
summary_str, (total_params, trainable_params) = _build_summary_string(summary, [input])

self.assertListEqual(list(summary.keys()), ['Conv2d-1', 'BatchNorm2d-2', 'MaxPool2d-3', 'ConvBlock-4',
'Conv2d-5', 'BatchNorm2d-6', 'MaxPool2d-7', 'ConvBlock-8',
'Dropout2d-9', 'Linear-10', 'Linear-11', 'NestedNet-12'])
self.assertEqual(total_params, 21900)
self.assertEqual(trainable_params, 21900)

summary = _build_summary_dict(model, [input], device='cpu', recurse=False)
summary_str, (total_params, trainable_params) = _build_summary_string(summary, [input])
self.assertListEqual(list(summary.keys()), ['ConvBlock-1', 'ConvBlock-2', 'Dropout2d-3', 'Linear-4',
'Linear-5', 'NestedNet-6'])
self.assertEqual(total_params, 21900)
self.assertEqual(trainable_params, 21900)

def test_custom_module(self):
model = CustomModule()
input = (1, 50)
total_params, trainable_params = summary(model, input, device='cpu')
self.assertEqual(total_params, 2500)
self.assertEqual(trainable_params, 2500)


class torchsummarystringTests(unittest.TestCase):
class TorchSummaryStringTests(unittest.TestCase):
def test_single_input(self):
model = SingleInputNet()
input = (1, 28, 28)
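The tests above exercise the private helpers directly, but the same behavior is reachable through the public summary API; a minimal usage sketch, assuming the NestedNet test model from the first file and a CPU device:

```python
from torchsummary import summary
from torchsummary.tests.test_models.test_model import NestedNet

model = NestedNet()

# recurse=True (the default): every leaf module gets its own row, and the
# listing ends with the root row -- Conv2d-1 ... NestedNet-12, as asserted
# in test_recursive above.
summary(model, (1, 28, 28), device='cpu')

# recurse=False: only the model's direct children plus the model itself are
# hooked, so each ConvBlock collapses into a single row
# (ConvBlock-1 ... NestedNet-6).
summary(model, (1, 28, 28), device='cpu', recurse=False)
```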
56 changes: 34 additions & 22 deletions torchsummary/torchsummary.py
@@ -1,25 +1,31 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

from collections import OrderedDict
import numpy as np


def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
result, params_info = summary_string(
model, input_size, batch_size, device, dtypes)
def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
result, params_info = summary_string(model, input_size, batch_size, device, dtypes, recurse)
print(result)

return params_info


def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None):
def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]

summary = _build_summary_dict(
model, input_size, batch_size, device, dtypes, recurse)
return _build_summary_string(summary, input_size, batch_size)


def _build_summary_dict(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None, recurse=True):
if dtypes == None:
dtypes = [torch.FloatTensor]*len(input_size)

summary_str = ''

def register_hook(module):
def hook(module, input, output):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
@@ -37,24 +43,21 @@ def hook(module, input, output):
summary[m_key]["output_shape"] = list(output.size())
summary[m_key]["output_shape"][0] = batch_size

params = 0
if hasattr(module, "weight") and hasattr(module.weight, "size"):
params += torch.prod(torch.LongTensor(list(module.weight.size())))
summary[m_key]["trainable"] = module.weight.requires_grad
if hasattr(module, "bias") and hasattr(module.bias, "size"):
params += torch.prod(torch.LongTensor(list(module.bias.size())))
summary[m_key]["nb_params"] = params
nb_params = 0
trainable_params = 0
for name, p in module.named_parameters():
params = torch.numel(p)
nb_params += params
trainable_params += params if p.requires_grad else 0
summary[m_key]["nb_params"] = nb_params
summary[m_key]["trainable"] = trainable_params

if (
not isinstance(module, nn.Sequential)
and not isinstance(module, nn.ModuleList)
):
hooks.append(module.register_forward_hook(hook))

# multiple inputs to the network
if isinstance(input_size, tuple):
input_size = [input_size]

# batch_size of 2 for batchnorm
x = [torch.rand(2, *in_size).type(dtype).to(device=device)
for in_size, dtype in zip(input_size, dtypes)]
@@ -64,7 +67,11 @@ def hook(module, input, output):
hooks = []

# register hook
model.apply(register_hook)
if recurse:
model.apply(register_hook)
else:
[register_hook(m) for m in model.children()]
register_hook(model)

# make a forward pass
# print(x.shape)
@@ -74,6 +81,12 @@ def hook(module, input, output):
for h in hooks:
h.remove()

return summary


def _build_summary_string(summary, input_size, batch_size=-1):

summary_str = ''
summary_str += "----------------------------------------------------------------" + "\n"
line_new = "{:>20} {:>25} {:>15}".format(
"Layer (type)", "Output Shape", "Param #")
@@ -89,12 +102,11 @@ def hook(module, input, output):
str(summary[layer]["output_shape"]),
"{0:,}".format(summary[layer]["nb_params"]),
)
total_params += summary[layer]["nb_params"]
total_params = summary[layer]["nb_params"]

total_output += np.prod(summary[layer]["output_shape"])
if "trainable" in summary[layer]:
if summary[layer]["trainable"] == True:
trainable_params += summary[layer]["nb_params"]
trainable_params = summary[layer]["trainable"]
summary_str += line_new + "\n"

# assume 4 bytes/number (float on cuda).
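Two details of this refactor are easy to miss. First, the hook now counts with module.named_parameters(), which sees every nn.Parameter however it is named and recurses into submodules by default, so the root module's row carries the grand total for the whole network (hence 'NestedNet-12' closing the listing). Second, that is why _build_summary_string switches from total_params += ... to total_params = ...: the last row visited is the root, whose nb_params already is the network-wide total, so accumulating would double-count every nested parameter. A standalone sketch of the new counting logic in plain PyTorch; the count_params helper name is my own, not part of the PR:

```python
import torch
import torch.nn as nn

def count_params(module):
    # Mirrors the new hook body: named_parameters() replaces the old
    # weight/bias attribute checks and recurses into children by default.
    nb_params, trainable = 0, 0
    for _, p in module.named_parameters():
        n = p.numel()
        nb_params += n
        trainable += n if p.requires_grad else 0
    return nb_params, trainable

block = nn.Sequential(nn.Linear(320, 50), nn.Linear(50, 10))
print(count_params(block))     # (16560, 16560): the root row sees both children
print(count_params(block[0]))  # (16050, 16050): a child row sees only itself
block[0].weight.requires_grad_(False)
print(count_params(block))     # (16560, 560): frozen weights leave the trainable count
```

This also explains why "trainable" changes from a bool to a per-module count in the summary dict: _build_summary_string can then read the final trainable total the same way it reads total_params.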