From 619665ca0ee92731ce59f60031b447469cb3ded1 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 7 Mar 2019 16:55:21 +0300 Subject: [PATCH 1/2] Linknet model --- .../linknet/__init__.py | 1 + .../linknet/decoder.py | 75 +++++++++++++++++++ segmentation_models_pytorch/linknet/model.py | 31 ++++++++ 3 files changed, 107 insertions(+) create mode 100644 segmentation_models_pytorch/linknet/__init__.py create mode 100644 segmentation_models_pytorch/linknet/decoder.py create mode 100644 segmentation_models_pytorch/linknet/model.py diff --git a/segmentation_models_pytorch/linknet/__init__.py b/segmentation_models_pytorch/linknet/__init__.py new file mode 100644 index 00000000..6a26d572 --- /dev/null +++ b/segmentation_models_pytorch/linknet/__init__.py @@ -0,0 +1 @@ +from .model import Linknet \ No newline at end of file diff --git a/segmentation_models_pytorch/linknet/decoder.py b/segmentation_models_pytorch/linknet/decoder.py new file mode 100644 index 00000000..3d93685b --- /dev/null +++ b/segmentation_models_pytorch/linknet/decoder.py @@ -0,0 +1,75 @@ +import torch.nn as nn + +from ..common.blocks import Conv2dReLU +from ..base.model import Model + + +class TransposeX2(nn.Module): + + def __init__(self, in_channels, out_channels, use_batchnorm=True, **batchnorm_params): + super().__init__() + layers = [ + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1), + nn.ReLU(inplace=True), + ] + if use_batchnorm: + layers.insert(1, nn.BatchNorm2d(out_channels, **batchnorm_params)) + + self.block = nn.Sequential(*layers) + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, use_batchnorm=True): + super().__init__() + + self.block = nn.Sequential( + Conv2dReLU(in_channels, in_channels // 4, kernel_size=1, use_batchnorm=use_batchnorm), + TransposeX2(in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm), + Conv2dReLU(in_channels // 4, out_channels, kernel_size=1, use_batchnorm=use_batchnorm), + ) + + def forward(self, x): + x, skip = x + x = self.block(x) + if skip is not None: + x += skip + return x + + +class LinknetDecoder(Model): + + def __init__( + self, + encoder_channels, + prefinal_channels=32, + final_channels=1, + use_batchnorm=True, + ): + super().__init__() + + in_channels = encoder_channels + + self.layer1 = DecoderBlock(in_channels[0], in_channels[1], use_batchnorm=use_batchnorm) + self.layer2 = DecoderBlock(in_channels[1], in_channels[2], use_batchnorm=use_batchnorm) + self.layer3 = DecoderBlock(in_channels[2], in_channels[3], use_batchnorm=use_batchnorm) + self.layer4 = DecoderBlock(in_channels[3], in_channels[4], use_batchnorm=use_batchnorm) + self.layer5 = DecoderBlock(in_channels[4], prefinal_channels, use_batchnorm=use_batchnorm) + self.final_conv = nn.Conv2d(prefinal_channels, final_channels, kernel_size=(1, 1)) + + self.initialize() + + def forward(self, x): + encoder_head = x[0] + skips = x[1:] + + x = self.layer1([encoder_head, skips[0]]) + x = self.layer2([x, skips[1]]) + x = self.layer3([x, skips[2]]) + x = self.layer4([x, skips[3]]) + x = self.layer5([x, None]) + x = self.final_conv(x) + + return x diff --git a/segmentation_models_pytorch/linknet/model.py b/segmentation_models_pytorch/linknet/model.py new file mode 100644 index 00000000..fd8f0d89 --- /dev/null +++ b/segmentation_models_pytorch/linknet/model.py @@ -0,0 +1,31 @@ +from .decoder import LinknetDecoder +from ..base import EncoderDecoder +from ..encoders import get_encoder + + +class Linknet(EncoderDecoder): + + def __init__( + self, + encoder_name='resnet34', + encoder_weights='imagenet', + decoder_use_batchnorm=True, + classes=1, + activation='sigmoid', + ): + + encoder = get_encoder( + encoder_name, + encoder_weights=encoder_weights + ) + + decoder = LinknetDecoder( + encoder_channels=encoder.out_shapes, + prefinal_channels=32, + final_channels=classes, + use_batchnorm=decoder_use_batchnorm, + ) + + super().__init__(encoder, decoder, activation) + + self.name = 'link-{}'.format(encoder_name) From da8bce64c30b12471fcd938bd7c740fcda7743c7 Mon Sep 17 00:00:00 2001 From: qubvel Date: Thu, 7 Mar 2019 16:55:27 +0300 Subject: [PATCH 2/2] Linknet model --- segmentation_models_pytorch/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 3581d959..2bc1693b 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -1,3 +1,4 @@ from .unet import Unet +from .linknet import Linknet from . import encoders \ No newline at end of file