Skip to content

Add U-Net model #899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ architectures for image classification:
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `ResNeXt`_
- `UNet`_

You can construct a model with random weights by calling its constructor:

Expand All @@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor:
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
resnext50_32x4d = models.resnext50_32x4d()
unet = model.unet23()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:
Expand All @@ -65,7 +67,7 @@ This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

All pre-trained models expect input images normalized in the same way,
i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
Expand Down Expand Up @@ -124,6 +126,7 @@ ResNeXt-101-32x8d 20.69 5.47
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _UNet: https://arxiv.org/abs/1505.04597

.. currentmodule:: torchvision.models

Expand Down Expand Up @@ -359,3 +362,12 @@ Keypoint R-CNN

.. autofunction:: torchvision.models.detection.keypointrcnn_resnet50_fpn

UNet
____

.. autofunction:: unet8
.. autofunction:: unet13
.. autofunction:: unet18
.. autofunction:: unet23
.. autofunction:: unet28
.. autofunction:: unet33
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .googlenet import *
from .mobilenet import *
from .shufflenetv2 import *
from .unet import *
from . import segmentation
from . import detection
154 changes: 154 additions & 0 deletions torchvision/models/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn


__all__ = ['UNet', 'unet8', 'unet13', 'unet18', 'unet23', 'unet28', 'unet33']


def double_conv(in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3),
nn.ReLU(inplace=True)
)


def center_crop(img, output_size):
_, _, h, w = img.size()
_, _, th, tw = output_size
i = (h - th) // 2
j = (w - tw) // 2
return img[:, :, i:i + th, j:j + tw]


class Contract(nn.Module):

def __init__(self, in_channels, out_channels, dropout=False, p=0.5):
super(Contract, self).__init__()
assert in_channels < out_channels

self.pool = nn.MaxPool2d(2)
self.conv = double_conv(in_channels, out_channels)
self.drop = None

if dropout:
self.drop = nn.Dropout2d(p=p)

def forward(self, x):
x = self.pool(x)
x = self.conv(x)

if self.drop is not None:
x = self.drop(x)

return x


class Expand(nn.Module):

def __init__(self, in_channels, out_channels):
super(Expand, self).__init__()
assert in_channels > out_channels

self.upconv = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv = double_conv(in_channels, out_channels)

def forward(self, x, out):
x = self.upconv(x)
x = self.relu(x)

out = center_crop(out, x.size())
x = torch.cat([out, x], 1)

x = self.conv(x)

return x


class UNet(nn.Module):
"""`U-Net <https://arxiv.org/pdf/1505.04597.pdf>`_ architecture.

Args:
in_channels (int, optional): number of channels in input image
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the reason I chose in_channels=1 as the default for U-Net is because this is how the original U-Net paper is modeled, using a single channel grayscale microscope imagery dataset (see #900). The application I needed it for was actually 4-channel microscope imagery, but unfortunately PIL doesn't support this (see #882). If we decide to pretrain this on COCO/Pascal I'm fine with switching the default to 3-channel.

num_classes (int, optional): number of classes in output segmentation
start_channels (int, optional): power of 2 channels to start with
depth (int, optional): number of contractions/expansions
p (float, optional): dropout probability
"""

def __init__(self, in_channels=1, num_classes=2, start_channels=6,
depth=4, p=0.5):
super(UNet, self).__init__()

self.depth = depth

# Contraction
self.conv1 = double_conv(in_channels, 2 ** start_channels)
self.contractions = nn.ModuleList([
Contract(2 ** d, 2 ** (d + 1), dropout=d - depth > 3, p=p)
for d in range(start_channels, start_channels + depth)
])

# Expansion
self.expansions = nn.ModuleList([
Expand(2 ** d, 2 ** (d - 1)) for d in range(
start_channels + depth, start_channels, -1)
])
self.conv2 = nn.Conv2d(2 ** start_channels, num_classes, 1)
self.softmax = nn.LogSoftmax(dim=1)

# Initialize weights
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
nn.init.constant_(m.bias, 0)

def forward(self, x):
# Contraction
out = [self.conv1(x)]
for f in self.contractions:
out.append(f(out[-1]))

# Expansion
i = -2
x = out[-1]
for f in self.expansions:
x = f(x, out[i])
i -= 1

x = self.conv2(x)
x = self.softmax(x)

return x


def unet8(**kwargs):
"""Constructs a U-Net 8 model."""
return UNet(depth=1, **kwargs)


def unet13(**kwargs):
"""Constructs a U-Net 13 model."""
return UNet(depth=2, **kwargs)


def unet18(**kwargs):
"""Constructs a U-Net 18 model."""
return UNet(depth=3, **kwargs)


def unet23(**kwargs):
"""Constructs a U-Net 23 model."""
return UNet(depth=4, **kwargs)


def unet28(**kwargs):
"""Constructs a U-Net 28 model."""
return UNet(depth=5, **kwargs)


def unet33(**kwargs):
"""Constructs a U-Net 33 model."""
return UNet(depth=6, **kwargs)