Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
#### Architectures <a name="architectires"></a>
- Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
- Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
- MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
- Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
- FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
- PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
Expand Down
4 changes: 4 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Unet++
~~~~~~
.. autoclass:: segmentation_models_pytorch.UnetPlusPlus

MAnet
~~~~~~
.. autoclass:: segmentation_models_pytorch.MAnet

Linknet
~~~~~~~
.. autoclass:: segmentation_models_pytorch.Linknet
Expand Down
5 changes: 3 additions & 2 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .unet import Unet
from .unetplusplus import UnetPlusPlus
from .manet import MAnet
from .linknet import Linknet
from .fpn import FPN
from .pspnet import PSPNet
Expand All @@ -24,9 +25,9 @@ def create_model(
**kwargs,
) -> torch.nn.Module:
"""Models wrapper. Allows to create any model just with parametes

"""

archs = [Unet, UnetPlusPlus, Linknet, FPN, PSPNet, DeepLabV3, DeepLabV3Plus, PAN]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
Expand Down
1 change: 1 addition & 0 deletions segmentation_models_pytorch/manet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model import MAnet
188 changes: 188 additions & 0 deletions segmentation_models_pytorch/manet/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..base import modules as md


class PAB(nn.Module):
def __init__(self, in_channels, out_channels, im_channels=64):
super(PAB, self).__init__()
# Series of 1x1 conv to generate attention feature maps
self.im_channels = im_channels
self.in_channels = in_channels
self.top_conv = nn.Conv2d(in_channels, im_channels, kernel_size=1)
self.center_conv = nn.Conv2d(in_channels, im_channels, kernel_size=1)
self.bottom_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.map_softmax = nn.Softmax(dim=1)
self.out_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

def forward(self, x):
bsize = x.size()[0]
h = x.size()[2]
w = x.size()[3]
x_top = self.top_conv(x)
x_center = self.center_conv(x)
x_bottom = self.bottom_conv(x)

x_top = x_top.flatten(2)
x_center = x_center.flatten(2).transpose(1, 2)
x_bottom = x_bottom.flatten(2).transpose(1, 2)

sp_map = torch.matmul(x_center, x_top)
sp_map = self.map_softmax(sp_map.view(bsize, -1)).view(bsize, h*w, h*w)
sp_map = torch.matmul(sp_map, x_bottom)
sp_map = sp_map.reshape(bsize, self.in_channels, h, w)
x = x + sp_map
x = self.out_conv(x)
return x


class MFAB(nn.Module):
def __init__(self, in_channels, skip_channels, out_channels, use_batchnorm=True, reduction=16):
# MFAB is just a modified version of SE-blocks, one for skip, one for input
super(MFAB, self).__init__()
self.hl_conv = nn.Sequential(
md.Conv2dReLU(
in_channels,
in_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
),
md.Conv2dReLU(
in_channels,
skip_channels,
kernel_size=1,
use_batchnorm=use_batchnorm,
)
)
self.SE_ll = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.SE_hl = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(skip_channels, skip_channels // reduction, 1),
nn.ReLU(inplace=True),
nn.Conv2d(skip_channels // reduction, skip_channels, 1),
nn.Sigmoid(),
)
self.conv1 = md.Conv2dReLU(
skip_channels + skip_channels, # we transform C-prime form high level to C from skip connection
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)

def forward(self, x, skip=None):
x = self.hl_conv(x)
x = F.interpolate(x, scale_factor=2, mode="nearest")
attention_hl = self.SE_hl(x)
if skip is not None:
attention_ll = self.SE_ll(skip)
attention_hl = attention_hl + attention_ll
x = x * attention_hl
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class DecoderBlock(nn.Module):
def __init__(
self,
in_channels,
skip_channels,
out_channels,
use_batchnorm=True
):
super().__init__()
self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)
self.conv2 = md.Conv2dReLU(
out_channels,
out_channels,
kernel_size=3,
padding=1,
use_batchnorm=use_batchnorm,
)

def forward(self, x, skip=None):
x = F.interpolate(x, scale_factor=2, mode="nearest")
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.conv1(x)
x = self.conv2(x)
return x


class MAnetDecoder(nn.Module):
def __init__(
self,
encoder_channels,
decoder_channels,
n_blocks=5,
reduction=16,
use_batchnorm=True,
im_channels=64
):
super().__init__()

if n_blocks != len(decoder_channels):
raise ValueError(
"Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
n_blocks, len(decoder_channels)
)
)

encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution
encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder

# computing blocks input and output channels
head_channels = encoder_channels[0]
in_channels = [head_channels] + list(decoder_channels[:-1])
skip_channels = list(encoder_channels[1:]) + [0]
out_channels = decoder_channels

self.center = PAB(head_channels, head_channels, im_channels=im_channels)

# combine decoder keyword arguments
kwargs = dict(use_batchnorm=use_batchnorm) # no attention type here
blocks = [
MFAB(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs) if skip_ch > 0 else
DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
]
# for the last we dont have skip connection -> use simple decoder block
self.blocks = nn.ModuleList(blocks)

def forward(self, *features):

features = features[1:] # remove first skip with same spatial resolution
features = features[::-1] # reverse channels to start from head of encoder

head = features[0]
skips = features[1:]

x = self.center(head)
for i, decoder_block in enumerate(self.blocks):
skip = skips[i] if i < len(skips) else None
x = decoder_block(x, skip)

return x
92 changes: 92 additions & 0 deletions segmentation_models_pytorch/manet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Optional, Union, List
from .decoder import MAnetDecoder
from ..encoders import get_encoder
from ..base import SegmentationModel
from ..base import SegmentationHead, ClassificationHead


class MAnet(SegmentationModel):
"""MAnet_ : Multi-scale Attention Net.
The MA-Net can capture rich contextual dependencies based on the attention mechanism, using two blocks:
Position-wise Attention Block (PAB, which captures the spatial dependencies between pixels in a global view)
and Multi-scale Fusion Attention Block (MFAB, which captures the channel dependencies between any feature map by
multi-scale semantic feature fusion)
Args:
encoder_name: name of classification model (without last dense layers) used as feature
extractor to build segmentation model.
encoder_depth (int): number of stages used in decoder, larger depth - more features are generated.
e.g. for depth=3 encoder will generate list of features with following spatial shapes
[(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have
spatial resolution (H/(2^depth), W/(2^depth)]
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption.
One of [True, False, 'inplace']
decoder_im_channels: number of layers for PAB layer.
in_channels: number of input channels for model, default is 3.
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
activation: activation function to apply after final convolution;
One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None]
aux_params: if specified model will have additional classification auxiliary output
build on top of encoder, supported params:
- classes (int): number of classes
- pooling (str): one of 'max', 'avg'. Default is 'avg'.
- dropout (float): dropout factor in [0, 1)
- activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
Returns:
``torch.nn.Module``: **MAnet**
.. _MAnet:
https://ieeexplore.ieee.org/abstract/document/9201310
"""

def __init__(
self,
encoder_name: str = "resnet34",
encoder_depth: int = 5,
encoder_weights: str = "imagenet",
decoder_use_batchnorm: bool = True,
decoder_channels: List[int] = (256, 128, 64, 32, 16),
decoder_im_channels: int = 64,
in_channels: int = 3,
classes: int = 1,
activation: Optional[Union[str, callable]] = None,
aux_params: Optional[dict] = None
):
super().__init__()

self.encoder = get_encoder(
encoder_name,
in_channels=in_channels,
depth=encoder_depth,
weights=encoder_weights,
)

self.decoder = MAnetDecoder(
encoder_channels=self.encoder.out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
use_batchnorm=decoder_use_batchnorm,
im_channels=decoder_im_channels
)

self.segmentation_head = SegmentationHead(
in_channels=decoder_channels[-1],
out_channels=classes,
activation=activation,
kernel_size=3,
)

if aux_params is not None:
self.classification_head = ClassificationHead(
in_channels=self.encoder.out_channels[-1], **aux_params
)
else:
self.classification_head = None

self.name = "manet-{}".format(encoder_name)
self.initialize()
8 changes: 4 additions & 4 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_encoders():


def get_sample(model_class):
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus]:
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet, smp.UnetPlusPlus, smp.MAnet]:
sample = torch.ones([1, 3, 64, 64])
elif model_class == smp.PAN:
sample = torch.ones([2, 3, 256, 256])
Expand Down Expand Up @@ -58,7 +58,7 @@ def _test_forward_backward(model, sample, test_shape=False):
@pytest.mark.parametrize("encoder_depth", [3, 5])
@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
if model_class is smp.Unet or model_class is smp.UnetPlusPlus:
if model_class is smp.Unet or model_class is smp.UnetPlusPlus or model_class is smp.MAnet:
kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:]
model = model_class(
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
Expand All @@ -75,15 +75,15 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):

@pytest.mark.parametrize(
"model_class",
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.DeepLabV3]
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet, smp.DeepLabV3]
)
def test_forward_backward(model_class):
sample = get_sample(model_class)
model = model_class(DEFAULT_ENCODER, encoder_weights=None)
_test_forward_backward(model, sample)


@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus])
@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.UnetPlusPlus, smp.MAnet])
def test_aux_output(model_class):
model = model_class(
DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
Expand Down