Add inceptionv4 pretrained model #43

Open · wants to merge 1 commit into main
5 changes: 5 additions & 0 deletions torchvision/models/__init__.py
@@ -4,6 +4,7 @@
- `AlexNet`_
- `VGG`_
- `ResNet`_
- `InceptionV4`_

You can construct a model with random weights by calling its constructor:

@@ -12,6 +13,7 @@
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
inceptionv4 = models.inceptionv4()

We provide pre-trained models for the ResNet variants, AlexNet, and InceptionV4, using the
PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
@@ -22,12 +24,15 @@
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
inceptionv4 = models.inceptionv4(pretrained=True)

.. _AlexNet: https://arxiv.org/abs/1404.5997
.. _VGG: https://arxiv.org/abs/1409.1556
.. _ResNet: https://arxiv.org/abs/1512.03385
.. _InceptionV4: https://arxiv.org/abs/1602.07261
"""

from .alexnet import *
from .resnet import *
from .vgg import *
from .inceptionv4 import *
272 changes: 272 additions & 0 deletions torchvision/models/inceptionv4.py
@@ -0,0 +1,272 @@
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

__all__ = ['InceptionV4', 'inceptionv4']

model_urls = {
'inceptionv4': 'https://s3.amazonaws.com/pytorch/models/inceptionv4-58153ba9.pth'
}

class BasicConv2d(nn.Module):

def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) # verify bias false
self.bn = nn.BatchNorm2d(out_planes, eps=0.001, momentum=0, affine=True)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x

class Mixed_3a(nn.Module):

def __init__(self):
super(Mixed_3a, self).__init__()
self.maxpool = nn.MaxPool2d(3, stride=2)
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)

def forward(self, x):
x0 = self.maxpool(x)
x1 = self.conv(x)
out = torch.cat((x0, x1), 1)
return out

class Mixed_4a(nn.Module):

def __init__(self):
super(Mixed_4a, self).__init__()

self.block0 = nn.Sequential(
BasicConv2d(160, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1)
)

self.block1 = nn.Sequential(
BasicConv2d(160, 64, kernel_size=1, stride=1),
BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)),
BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)),
BasicConv2d(64, 96, kernel_size=(3,3), stride=1)
)

def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x)
out = torch.cat((x0, x1), 1)
return out

class Mixed_5a(nn.Module):

def __init__(self):
super(Mixed_5a, self).__init__()
self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
self.maxpool = nn.MaxPool2d(3, stride=2)

def forward(self, x):
x0 = self.conv(x)
x1 = self.maxpool(x)
out = torch.cat((x0, x1), 1)
return out

class Inception_A(nn.Module):

def __init__(self):
super(Inception_A, self).__init__()
self.block0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

self.block1 = nn.Sequential(
BasicConv2d(384, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
)

self.block2 = nn.Sequential(
BasicConv2d(384, 64, kernel_size=1, stride=1),
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
)

self.block3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(384, 96, kernel_size=1, stride=1)
)

def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x)
x2 = self.block2(x)
x3 = self.block3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out

class Reduction_A(nn.Module):

def __init__(self):
super(Reduction_A, self).__init__()
self.block0 = BasicConv2d(384, 384, kernel_size=3, stride=2)

self.block1 = nn.Sequential(
BasicConv2d(384, 192, kernel_size=1, stride=1),
BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
BasicConv2d(224, 256, kernel_size=3, stride=2)
)

self.block2 = nn.MaxPool2d(3, stride=2)

def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x)
x2 = self.block2(x)
out = torch.cat((x0, x1, x2), 1)
return out

class Inception_B(nn.Module):

def __init__(self):
super(Inception_B, self).__init__()
self.block0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)

self.block1 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)),
BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0))
)

self.block2 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)),
BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)),
BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)),
BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3))
)

self.block3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(1024, 128, kernel_size=1, stride=1)
)

def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x)
x2 = self.block2(x)
x3 = self.block3(x)
out = torch.cat((x0, x1, x2, x3), 1)
return out

class Reduction_B(nn.Module):

def __init__(self):
super(Reduction_B, self).__init__()

self.block0 = nn.Sequential(
BasicConv2d(1024, 192, kernel_size=1, stride=1),
BasicConv2d(192, 192, kernel_size=3, stride=2)
)

self.block1 = nn.Sequential(
BasicConv2d(1024, 256, kernel_size=1, stride=1),
BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)),
BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)),
BasicConv2d(320, 320, kernel_size=3, stride=2)
)

self.block2 = nn.MaxPool2d(3, stride=2)

def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x)
x2 = self.block2(x)
out = torch.cat((x0, x1, x2), 1)
return out

class Inception_C(nn.Module):

def __init__(self):
super(Inception_C, self).__init__()
self.block0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

self.block1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
self.block1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1))
self.block1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0))

self.block2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
self.block2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0))
self.block2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1))
self.block2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1))
self.block2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0))

self.block3 = nn.Sequential(
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
BasicConv2d(1536, 256, kernel_size=1, stride=1)
)

def forward(self, x):
x0 = self.block0(x)

x1_0 = self.block1_0(x)
x1_1a = self.block1_1a(x1_0)
x1_1b = self.block1_1b(x1_0)
x1 = torch.cat((x1_1a, x1_1b), 1)

x2_0 = self.block2_0(x)
x2_1 = self.block2_1(x2_0)
x2_2 = self.block2_2(x2_1)
x2_3a = self.block2_3a(x2_2)
x2_3b = self.block2_3b(x2_2)
x2 = torch.cat((x2_3a, x2_3b), 1)

x3 = self.block3(x)

out = torch.cat((x0, x1, x2, x3), 1)
return out

class InceptionV4(nn.Module):

def __init__(self, num_classes=1001):  # 1001 = 1000 ImageNet classes plus a background class, matching the original TF checkpoint
super(InceptionV4, self).__init__()
self.features = nn.Sequential(
BasicConv2d(3, 32, kernel_size=3, stride=2),
BasicConv2d(32, 32, kernel_size=3, stride=1),
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
Mixed_3a(),
Mixed_4a(),
Mixed_5a(),
Inception_A(),
Inception_A(),
Inception_A(),
Inception_A(),
Reduction_A(), # Mixed_6a
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Inception_B(),
Reduction_B(), # Mixed_7a
Inception_C(),
Inception_C(),
Inception_C(),
nn.AvgPool2d(8, count_include_pad=False)
)
self.classif = nn.Linear(1536, num_classes)

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classif(x)
return x


def inceptionv4(pretrained=False):
r"""InceptionV4 model architecture from the
`"Inception-v4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = InceptionV4()
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['inceptionv4']))
return model
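
For reference, a minimal usage sketch (not part of the diff above). It assumes a 299x299 input, since the stem and reduction blocks reduce that to the 8x8 feature maps expected by the final nn.AvgPool2d(8), and a PyTorch version where tensors can be passed to modules directly (older releases needed torch.autograd.Variable):

    import torch
    import torchvision.models as models

    # random weights; pretrained=True downloads the checkpoint from model_urls
    model = models.inceptionv4()
    model.eval()

    # dummy batch: 1 image, 3 channels, 299x299
    x = torch.randn(1, 3, 299, 299)
    out = model(x)
    print(out.size())  # (1, 1001) with the default num_classes=1001

Inputs of any other spatial size will not match the fixed 8x8 pooling window, so the forward pass will either fail or produce flattened features that do not fit the 1536-unit classifier.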