diff --git a/docs/source/models.rst b/docs/source/models.rst
index f27a555befe..4422e3776f9 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -24,6 +24,7 @@ architectures for image classification:
 - `ShuffleNet`_ v2
 - `MobileNet`_ v2
 - `ResNeXt`_
+- `UNet`_
 
 You can construct a model with random weights by calling its constructor:
 
@@ -40,6 +41,7 @@ You can construct a model with random weights by calling its constructor:
     shufflenet = models.shufflenet_v2_x1_0()
     mobilenet = models.mobilenet_v2()
     resnext50_32x4d = models.resnext50_32x4d()
+    unet = models.unet23()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
 These can be constructed by passing ``pretrained=True``:
@@ -65,7 +67,7 @@ This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See
 Some models use modules which have different training and evaluation
 behavior, such as batch normalization. To switch between these modes, use
 ``model.train()`` or ``model.eval()`` as appropriate. See
-:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
+:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.
 
 All pre-trained models expect input images normalized in the same way,
 i.e. mini-batches of 3-channel RGB images of shape (3 x H x W),
@@ -124,6 +126,7 @@ ResNeXt-101-32x8d 20.69 5.47
 .. _ShuffleNet: https://arxiv.org/abs/1807.11164
 .. _MobileNet: https://arxiv.org/abs/1801.04381
 .. _ResNeXt: https://arxiv.org/abs/1611.05431
+.. _UNet: https://arxiv.org/abs/1505.04597
 
 .. currentmodule:: torchvision.models
 
@@ -359,3 +362,12 @@ Keypoint R-CNN
 
 .. autofunction:: torchvision.models.detection.keypointrcnn_resnet50_fpn
 
+UNet
+____
+
+.. autofunction:: unet8
+.. autofunction:: unet13
+.. autofunction:: unet18
+.. autofunction:: unet23
+.. autofunction:: unet28
+.. autofunction:: unet33
diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py
index 7f460999296..485e78402a6 100644
--- a/torchvision/models/__init__.py
+++ b/torchvision/models/__init__.py
@@ -7,5 +7,6 @@
 from .googlenet import *
 from .mobilenet import *
 from .shufflenetv2 import *
+from .unet import *
 from . import segmentation
 from . import detection
diff --git a/torchvision/models/unet.py b/torchvision/models/unet.py
new file mode 100644
index 00000000000..f6cc188719a
--- /dev/null
+++ b/torchvision/models/unet.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+
+
+__all__ = ['UNet', 'unet8', 'unet13', 'unet18', 'unet23', 'unet28', 'unet33']
+
+
+def double_conv(in_channels, out_channels):
+    return nn.Sequential(
+        nn.Conv2d(in_channels, out_channels, 3),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(out_channels, out_channels, 3),
+        nn.ReLU(inplace=True)
+    )
+
+
+def center_crop(img, output_size):
+    _, _, h, w = img.size()
+    _, _, th, tw = output_size
+    i = (h - th) // 2
+    j = (w - tw) // 2
+    return img[:, :, i:i + th, j:j + tw]
+
+
+class Contract(nn.Module):
+
+    def __init__(self, in_channels, out_channels, dropout=False, p=0.5):
+        super(Contract, self).__init__()
+        assert in_channels < out_channels
+
+        self.pool = nn.MaxPool2d(2)
+        self.conv = double_conv(in_channels, out_channels)
+        self.drop = None
+
+        if dropout:
+            self.drop = nn.Dropout2d(p=p)
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+
+        if self.drop is not None:
+            x = self.drop(x)
+
+        return x
+
+
+class Expand(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super(Expand, self).__init__()
+        assert in_channels > out_channels
+
+        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv = double_conv(in_channels, out_channels)
+
+    def forward(self, x, out):
+        x = self.upconv(x)
+        x = self.relu(x)
+
+        out = center_crop(out, x.size())
+        x = torch.cat([out, x], 1)
+
+        x = self.conv(x)
+
+        return x
+
+
+class UNet(nn.Module):
+    """`U-Net <https://arxiv.org/abs/1505.04597>`_ architecture.
+
+    Args:
+        in_channels (int, optional): number of channels in input image
+        num_classes (int, optional): number of classes in output segmentation
+        start_channels (int, optional): power of 2 channels to start with
+        depth (int, optional): number of contractions/expansions
+        p (float, optional): dropout probability
+    """
+
+    def __init__(self, in_channels=1, num_classes=2, start_channels=6,
+                 depth=4, p=0.5):
+        super(UNet, self).__init__()
+
+        self.depth = depth
+
+        # Contraction
+        self.conv1 = double_conv(in_channels, 2 ** start_channels)
+        self.contractions = nn.ModuleList([
+            Contract(2 ** d, 2 ** (d + 1), dropout=d - depth > 3, p=p)
+            for d in range(start_channels, start_channels + depth)
+        ])
+
+        # Expansion
+        self.expansions = nn.ModuleList([
+            Expand(2 ** d, 2 ** (d - 1)) for d in range(
+                start_channels + depth, start_channels, -1)
+        ])
+        self.conv2 = nn.Conv2d(2 ** start_channels, num_classes, 1)
+        self.softmax = nn.LogSoftmax(dim=1)
+
+        # Initialize weights
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        # Contraction
+        out = [self.conv1(x)]
+        for f in self.contractions:
+            out.append(f(out[-1]))
+
+        # Expansion
+        i = -2
+        x = out[-1]
+        for f in self.expansions:
+            x = f(x, out[i])
+            i -= 1
+
+        x = self.conv2(x)
+        x = self.softmax(x)
+
+        return x
+
+
+def unet8(**kwargs):
+    """Constructs a U-Net 8 model."""
+    return UNet(depth=1, **kwargs)
+
+
+def unet13(**kwargs):
+    """Constructs a U-Net 13 model."""
+    return UNet(depth=2, **kwargs)
+
+
+def unet18(**kwargs):
+    """Constructs a U-Net 18 model."""
+    return UNet(depth=3, **kwargs)
+
+
+def unet23(**kwargs):
+    """Constructs a U-Net 23 model."""
+    return UNet(depth=4, **kwargs)
+
+
+def unet28(**kwargs):
+    """Constructs a U-Net 28 model."""
+    return UNet(depth=5, **kwargs)
+
+
+def unet33(**kwargs):
+    """Constructs a U-Net 33 model."""
+    return UNet(depth=6, **kwargs)
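
A quick smoke test for the new module (a minimal sketch, assuming the patch above is applied; ``unet23`` and its defaults are taken directly from the diff). Because ``double_conv`` uses unpadded 3x3 convolutions, the output map is smaller than the input; 572x572 is the input size used in the U-Net paper for the default depth of 4:

    import torch
    from torchvision.models import unet23

    model = unet23(in_channels=1, num_classes=2)
    model.eval()

    # Unpadded convolutions shrink a 572x572 input to a 388x388 output:
    # 572 -> 568 after conv1, then 280, 136, 64, 28 down the contracting
    # path, then 52, 100, 196, 388 back up the expanding path.
    x = torch.randn(1, 1, 572, 572)
    with torch.no_grad():
        out = model(x)

    print(out.shape)  # torch.Size([1, 2, 388, 388])

Note that ``forward`` ends in ``nn.LogSoftmax``, so the output holds log-probabilities; at training time, targets center-cropped to the 388x388 output would pair with ``nn.NLLLoss`` rather than ``nn.CrossEntropyLoss``.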