Skip to content

Commit 1531bf5

Browse files
committed
Add ResNet, AlexNet, and VGG model definitions and model zoo
1 parent 3ed4831 commit 1531bf5

File tree

5 files changed

+325
-0
lines changed

5 files changed

+325
-0
lines changed

torchvision/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from torchvision import models
2+
from torchvision import datasets
3+
from torchvision import transforms
4+
from torchvision import utils

torchvision/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .alexnet import *
2+
from .resnet import *
3+
from .vgg import *

torchvision/models/alexnet.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch.nn as nn
2+
import torch.utils.model_zoo as model_zoo
3+
4+
5+
__all__ = ['AlexNet', 'alexnet']
6+
7+
8+
model_urls = {
9+
'alexnet': 'https://s3.amazonaws.com/pytorch/models/alexnet-owt-4df8aa71.pth',
10+
}
11+
12+
13+
class AlexNet(nn.Container):
14+
def __init__(self, num_classes=1000):
15+
super(AlexNet, self).__init__()
16+
self.features = nn.Sequential(
17+
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
18+
nn.ReLU(inplace=True),
19+
nn.MaxPool2d(kernel_size=3, stride=2),
20+
nn.Conv2d(64, 192, kernel_size=5, padding=2),
21+
nn.ReLU(inplace=True),
22+
nn.MaxPool2d(kernel_size=3, stride=2),
23+
nn.Conv2d(192, 384, kernel_size=3, padding=1),
24+
nn.ReLU(inplace=True),
25+
nn.Conv2d(384, 256, kernel_size=3, padding=1),
26+
nn.ReLU(inplace=True),
27+
nn.Conv2d(256, 256, kernel_size=3, padding=1),
28+
nn.ReLU(inplace=True),
29+
nn.MaxPool2d(kernel_size=3, stride=2),
30+
)
31+
self.classifier = nn.Sequential(
32+
nn.Dropout(),
33+
nn.Linear(256 * 6 * 6, 4096),
34+
nn.ReLU(inplace=True),
35+
nn.Dropout(),
36+
nn.Linear(4096, 4096),
37+
nn.ReLU(inplace=True),
38+
nn.Linear(4096, num_classes),
39+
)
40+
41+
def forward(self, x):
42+
x = self.features(x)
43+
x = x.view(x.size(0), 256 * 6 * 6)
44+
x = self.classifier(x)
45+
return x
46+
47+
48+
def alexnet(pretrained=False):
49+
r"""AlexNet model architecture from the "One weird trick" paper.
50+
https://arxiv.org/abs/1404.5997
51+
"""
52+
model = AlexNet()
53+
if pretrained:
54+
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
55+
return model

torchvision/models/resnet.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import torch.nn as nn
2+
import math
3+
import torch.utils.model_zoo as model_zoo
4+
5+
6+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7+
'resnet152']
8+
9+
10+
model_urls = {
11+
'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
12+
'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
13+
'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
14+
}
15+
16+
17+
def conv3x3(in_planes, out_planes, stride=1):
18+
"3x3 convolution with padding"
19+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20+
padding=1, bias=False)
21+
22+
23+
class BasicBlock(nn.Container):
24+
expansion = 1
25+
26+
def __init__(self, inplanes, planes, stride=1, downsample=None):
27+
super(BasicBlock, self).__init__()
28+
self.conv1 = conv3x3(inplanes, planes, stride)
29+
self.bn1 = nn.BatchNorm2d(planes)
30+
self.relu = nn.ReLU(inplace=True)
31+
self.conv2 = conv3x3(planes, planes)
32+
self.bn2 = nn.BatchNorm2d(planes)
33+
self.downsample = downsample
34+
self.stride = stride
35+
36+
def forward(self, x):
37+
residual = x
38+
39+
out = self.conv1(x)
40+
out = self.bn1(out)
41+
out = self.relu(out)
42+
43+
out = self.conv2(out)
44+
out = self.bn2(out)
45+
46+
if self.downsample is not None:
47+
residual = self.downsample(x)
48+
49+
out += residual
50+
out = self.relu(out)
51+
52+
return out
53+
54+
55+
class Bottleneck(nn.Container):
56+
expansion = 4
57+
58+
def __init__(self, inplanes, planes, stride=1, downsample=None):
59+
super(Bottleneck, self).__init__()
60+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61+
self.bn1 = nn.BatchNorm2d(planes)
62+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63+
padding=1, bias=False)
64+
self.bn2 = nn.BatchNorm2d(planes)
65+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
66+
self.bn3 = nn.BatchNorm2d(planes * 4)
67+
self.relu = nn.ReLU(inplace=True)
68+
self.downsample = downsample
69+
self.stride = stride
70+
71+
def forward(self, x):
72+
residual = x
73+
74+
out = self.conv1(x)
75+
out = self.bn1(out)
76+
out = self.relu(out)
77+
78+
out = self.conv2(out)
79+
out = self.bn2(out)
80+
out = self.relu(out)
81+
82+
out = self.conv3(out)
83+
out = self.bn3(out)
84+
85+
if self.downsample is not None:
86+
residual = self.downsample(x)
87+
88+
out += residual
89+
out = self.relu(out)
90+
91+
return out
92+
93+
94+
class ResNet(nn.Container):
95+
def __init__(self, block, layers, num_classes=1000):
96+
self.inplanes = 64
97+
super(ResNet, self).__init__()
98+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
99+
bias=False)
100+
self.bn1 = nn.BatchNorm2d(64)
101+
self.relu = nn.ReLU(inplace=True)
102+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
103+
self.layer1 = self._make_layer(block, 64, layers[0])
104+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
105+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
106+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
107+
self.avgpool = nn.AvgPool2d(7)
108+
self.fc = nn.Linear(512 * block.expansion, num_classes)
109+
110+
for m in self.modules():
111+
if isinstance(m, nn.Conv2d):
112+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
113+
m.weight.data.normal_(0, math.sqrt(2. / n))
114+
elif isinstance(m, nn.BatchNorm2d):
115+
m.weight.data.fill_(1)
116+
m.bias.data.zero_()
117+
118+
def _make_layer(self, block, planes, blocks, stride=1):
119+
downsample = None
120+
if stride != 1 or self.inplanes != planes * block.expansion:
121+
downsample = nn.Sequential(
122+
nn.Conv2d(self.inplanes, planes * block.expansion,
123+
kernel_size=1, stride=stride, bias=False),
124+
nn.BatchNorm2d(planes * block.expansion),
125+
)
126+
127+
layers = []
128+
layers.append(block(self.inplanes, planes, stride, downsample))
129+
self.inplanes = planes * block.expansion
130+
for i in range(1, blocks):
131+
layers.append(block(self.inplanes, planes))
132+
133+
return nn.Sequential(*layers)
134+
135+
def forward(self, x):
136+
x = self.conv1(x)
137+
x = self.bn1(x)
138+
x = self.relu(x)
139+
x = self.maxpool(x)
140+
141+
x = self.layer1(x)
142+
x = self.layer2(x)
143+
x = self.layer3(x)
144+
x = self.layer4(x)
145+
146+
x = self.avgpool(x)
147+
x = x.view(x.size(0), -1)
148+
x = self.fc(x)
149+
150+
return x
151+
152+
153+
def resnet18(pretrained=False):
154+
model = ResNet(BasicBlock, [2, 2, 2, 2])
155+
if pretrained:
156+
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
157+
return model
158+
159+
160+
def resnet34(pretrained=False):
161+
model = ResNet(BasicBlock, [3, 4, 6, 3])
162+
if pretrained:
163+
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
164+
return model
165+
166+
167+
def resnet50(pretrained=False):
168+
model = ResNet(Bottleneck, [3, 4, 6, 3])
169+
if pretrained:
170+
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
171+
return model
172+
173+
174+
def resnet101():
175+
return ResNet(Bottleneck, [3, 4, 23, 3])
176+
177+
178+
def resnet152():
179+
return ResNet(Bottleneck, [3, 8, 36, 3])

torchvision/models/vgg.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import torch.nn as nn
2+
3+
4+
__all__ = [
5+
'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
6+
'vgg19_bn', 'vgg19',
7+
]
8+
9+
10+
class VGG(nn.Container):
11+
def __init__(self, features):
12+
super(VGG, self).__init__()
13+
self.features = features
14+
self.classifier = nn.Sequential(
15+
nn.Dropout(),
16+
nn.Linear(512 * 7 * 7, 4096),
17+
nn.ReLU(True),
18+
nn.Dropout(),
19+
nn.Linear(4096, 4096),
20+
nn.ReLU(True),
21+
nn.Linear(4096, 1000),
22+
)
23+
24+
def forward(self, x):
25+
x = self.features(x)
26+
x = x.view(x.size(0), -1)
27+
x = self.classifier(x)
28+
return x
29+
30+
31+
def make_layers(cfg, batch_norm=False):
32+
layers = []
33+
in_channels = 3
34+
for v in cfg:
35+
if v == 'M':
36+
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
37+
else:
38+
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
39+
if batch_norm:
40+
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
41+
else:
42+
layers += [conv2d, nn.ReLU(inplace=True)]
43+
in_channels = v
44+
return nn.Sequential(*layers)
45+
46+
47+
cfg = {
48+
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
49+
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
50+
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
51+
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
52+
}
53+
54+
55+
def vgg11():
56+
return VGG(make_layers(cfg['A']))
57+
58+
59+
def vgg11_bn():
60+
return VGG(make_layers(cfg['A'], batch_norm=True))
61+
62+
63+
def vgg13():
64+
return VGG(make_layers(cfg['B']))
65+
66+
67+
def vgg13_bn():
68+
return VGG(make_layers(cfg['B'], batch_norm=True))
69+
70+
71+
def vgg16():
72+
return VGG(make_layers(cfg['D']))
73+
74+
75+
def vgg16_bn():
76+
return VGG(make_layers(cfg['D'], batch_norm=True))
77+
78+
79+
def vgg19():
80+
return VGG(make_layers(cfg['E']))
81+
82+
83+
def vgg19_bn():
84+
return VGG(make_layers(cfg['E'], batch_norm=True))

0 commit comments

Comments
 (0)